
External data related issue when quantizing a large ONNX model (6GB) to fp8 #458

@jianlany


Describe the bug

ModelOpt fails to quantize my ONNX model to fp8, raising the following error:
onnx.onnx_cpp2py_export.checker.ValidationError: Data of TensorProto ( tensor name: onnx::MatMul_8237_scale) should be stored in model_f931b22a.onnx_data, but it doesn't exist or is not accessible.

However, quantizing a smaller model without external data works fine, and quantizing the large model to int8 also succeeds.

Comparing the log files shows that the working run (with the small model) does not print the warning message from graphsanitizer.py that is highlighted in the full log below.
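
For reference, the external-data reference recorded inside each initializer can be inspected with plain onnx utilities. This is just a diagnostic sketch (not part of my repro script); it loads the graph without pulling in the weight files and prints which .onnx_data file every external tensor expects:

import onnx
from onnx.external_data_helper import uses_external_data

# Load only the graph structure; skip reading the external weight files so this
# also works when the referenced .onnx_data file is missing.
model = onnx.load("model.onnx", load_external_data=False)

for init in model.graph.initializer:
    if uses_external_data(init):
        # Each external tensor records the expected file name under the
        # "location" key of its external_data entries.
        location = {entry.key: entry.value for entry in init.external_data}.get("location")
        print(init.name, "->", location)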

Steps/Code to reproduce bug

The original model is too large to attach (6 GB). The quantization call is:

import modelopt.onnx.quantization as moq

moq.quantize(
    onnx_path=onnx_in,
    quantize_mode="fp8",
    calibration_data=calib,                            # dict[str, ndarray]
    calibration_method="entropy",                      # or "max"
    calibration_eps=["cpu"],                           # calibrate on CPU for determinism
    # Good first pass: Transformer-safe allowlist
    op_types_to_quantize=["MatMul", "Gemm"],           # start conservative; you can widen later
    use_external_data_format=True,                     # writes .onnx + .onnx_data
    dq_only=False,                                     # emit full Q/DQ, not DQ-only
    # Optional: pin shapes (helps TRT later if your model is dynamic)
    calibration_shapes="input_ids:8x256,attention_mask:8x256",
    output_path=onnx_out,
)
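
For context, calib is assembled from 20 pre-dumped calibration samples, roughly like this (a simplified sketch; samples is a placeholder for the loaded feature dump, and the key names match the model's input names):

import numpy as np

# Stack the per-sample arrays into (batch, seq_len) = (20, 256) tensors.
input_ids = np.stack([s["input_ids"] for s in samples]).astype(np.int64)
attention_mask = np.stack([s["attention_mask"] for s in samples]).astype(np.int64)

calib = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}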

Full Log:

Loading feature dump...
Loaded 20 calibration samples
Created HF dataset
Stacking calibration data...
Calibration data shape - ids: (20, 256), mask: (20, 256)
Starting quantization...
Input model: model.onnx
Output model: model_fp8_qdq.onnx
[modelopt][onnx] - INFO - Starting quantization process for model: model.onnx
[modelopt][onnx] - INFO - Quantization mode: fp8
[modelopt][onnx] - INFO - Preprocessing the model model.onnx
[modelopt][onnx] - INFO - Model has dynamic inputs: ['input_ids', 'attention_mask']
[modelopt][onnx] - INFO - No custom ops found. If that's not correct, please make sure that the 'tensorrt' python package is correctly installed and that the paths to 'libcudnn*.so' and TensorRT 'lib/' are in 'LD_LIBRARY_PATH'. If the custom op is not directly available as a plugin in TensorRT, please also make sure that the path to the compiled '.so' TensorRT plugin is also being given via the '--trt_plugins' flag (requires TRT 10+).
[modelopt][onnx] - INFO - Duplicating shared constants
[modelopt][onnx] - DEBUG - Updating TRT EP support - DDS ops: False, Custom ops: False
[modelopt][onnx] - INFO - Setting up CalibrationDataProvider for calibration
[modelopt][onnx] - DEBUG - Getting input shapes from model
[modelopt][onnx] - DEBUG - Multi-tensor calibration data with 2 inputs
[modelopt][onnx] - DEBUG - Creating 2 calibration iterations
[modelopt][onnx] - INFO - Analyzing MHA nodes for fp8 quantization
[modelopt][onnx] - DEBUG - Model size: 6176683346 bytes, using external data: True
[modelopt][onnx] - WARNING - Removing existing external data file: model_extended.onnx_data
[modelopt][onnx] - INFO - Creating ORT InferenceSession
[modelopt][onnx] - DEBUG - Preparing execution providers list from: ['cpu']
[modelopt][onnx] - DEBUG - Added CPU EP
[modelopt][onnx] - INFO - Successfully enabled 1 EPs for ORT: ['CPUExecutionProvider']
[modelopt][onnx] - DEBUG - Creating session with providers: ['CPUExecutionProvider']
[modelopt][onnx] - DEBUG - Matmul nodes From MHA to exclude: []
[modelopt][onnx] - INFO - Starting FP8 quantization process
[modelopt][onnx] - INFO - Loading ONNX model from model.onnx
[modelopt][onnx] - DEBUG - Model size: 6176682096 bytes, using external data: True
[modelopt][onnx] - INFO - Detecting GEMV patterns for TRT optimization
[modelopt][onnx] - DEBUG - Found 254 MatMul nodes to analyze
[modelopt][onnx] - DEBUG - Matmul nodes to exclude: []
[modelopt][onnx] - DEBUG - Excluding 0 MatMul nodes due to GEMV pattern
[modelopt][onnx] - INFO - Scanning for unsupported Conv nodes for quantization
[modelopt][onnx] - INFO - Found 0 unsupported Conv nodes for quantization
[modelopt][onnx] - INFO - Configuring ORT for ModelOpt ONNX quantization
[modelopt][onnx] - DEBUG - Registering custom QDQ operators
[modelopt][onnx] - DEBUG - Patching ORT modules
[modelopt][onnx] - DEBUG - Removing non-quantizable ops from QDQ registry
[modelopt][onnx] - DEBUG - Preparing execution providers list from: ['cpu']
[modelopt][onnx] - DEBUG - Added CPU EP
[modelopt][onnx] - INFO - Successfully enabled 1 EPs for ORT: ['CPUExecutionProvider']
[modelopt][onnx] - DEBUG - Getting quantizable operator types
[modelopt][onnx] - INFO - Quantizable op types in the model: ['MatMul', 'Add']
[modelopt][onnx] - INFO - Finding nodes to quantize
[modelopt][onnx] - INFO - Building non-residual Add input map
[modelopt][onnx] - INFO - Searching for patterns like MHA, LayerNorm, etc
[modelopt][onnx] - INFO - Found 0 layer norm partitions
[modelopt][onnx] - INFO - Found 28 MHA (QK_AV) Patterns
[modelopt][onnx] - INFO - Found 29 non-quantizable partitions
[modelopt][onnx] - INFO - Building KGEN/CASK targeted partitions
[modelopt][onnx] - INFO - Classifying partition nodes
[modelopt][onnx] - INFO - Found 198 quantizable partition nodes and 169 quantizable KGEN heads
[modelopt][onnx] - INFO - Finding quantizable nodes. Initial nodes to quantize: 367
[modelopt][onnx] - INFO - Found 0 pooling/window ops
[modelopt][onnx] - INFO - Total number of quantizable nodes: 367
[modelopt][onnx] - DEBUG - Selected nodes to quantize: ['/model/layers.0/self_attn/Add', '/model/layers.0/self_attn/Add', '/model/layers.0/self_attn/Add_1', '/model/layers.0/self_attn/Add_1', '/model/layers.0/Add', '/model/layers.0/Add_1', '/model/layers.1/self_attn/Add', '/model/layers.1/self_attn/Add', '/model/layers.1/self_attn/Add_1', '/model/layers.1/self_attn/Add_1', '/model/layers.1/Add', '/model/layers.1/Add_1', '/model/layers.2/self_attn/Add', '/model/layers.2/self_attn/Add', '/model/layers.2/self_attn/Add_1', '/model/layers.2/self_attn/Add_1', '/model/layers.2/Add', '/model/layers.2/Add_1', '/model/layers.3/self_attn/Add', '/model/layers.3/self_attn/Add', '/model/layers.3/self_attn/Add_1', '/model/layers.3/self_attn/Add_1', '/model/layers.3/Add', '/model/layers.3/Add_1', '/model/layers.4/self_attn/Add', '/model/layers.4/self_attn/Add', '/model/layers.4/self_attn/Add_1', '/model/layers.4/self_attn/Add_1', '/model/layers.4/Add', '/model/layers.4/Add_1', '/model/layers.5/self_attn/Add', '/model/layers.5/self_attn/Add', '/model/layers.5/self_attn/Add_1', '/model/layers.5/self_attn/Add_1', '/model/layers.5/Add', '/model/layers.5/Add_1', '/model/layers.6/self_attn/Add', '/model/layers.6/self_attn/Add', '/model/layers.6/self_attn/Add_1', '/model/layers.6/self_attn/Add_1', '/model/layers.6/Add', '/model/layers.6/Add_1', '/model/layers.7/self_attn/Add', '/model/layers.7/self_attn/Add', '/model/layers.7/self_attn/Add_1', '/model/layers.7/self_attn/Add_1', '/model/layers.7/Add', '/model/layers.7/Add_1', '/model/layers.8/self_attn/Add', '/model/layers.8/self_attn/Add', '/model/layers.8/self_attn/Add_1', '/model/layers.8/self_attn/Add_1', '/model/layers.8/Add', '/model/layers.8/Add_1', '/model/layers.9/self_attn/Add', '/model/layers.9/self_attn/Add', '/model/layers.9/self_attn/Add_1', '/model/layers.9/self_attn/Add_1', '/model/layers.9/Add', '/model/layers.9/Add_1', '/model/layers.10/self_attn/Add', '/model/layers.10/self_attn/Add', '/model/layers.10/self_attn/Add_1', '/model/layers.10/self_attn/Add_1', '/model/layers.10/Add', '/model/layers.10/Add_1', '/model/layers.11/self_attn/Add', '/model/layers.11/self_attn/Add', '/model/layers.11/self_attn/Add_1', '/model/layers.11/self_attn/Add_1', '/model/layers.11/Add', '/model/layers.11/Add_1', '/model/layers.12/self_attn/Add', '/model/layers.12/self_attn/Add', '/model/layers.12/self_attn/Add_1', '/model/layers.12/self_attn/Add_1', '/model/layers.12/Add', '/model/layers.12/Add_1', '/model/layers.13/self_attn/Add', '/model/layers.13/self_attn/Add', '/model/layers.13/self_attn/Add_1', '/model/layers.13/self_attn/Add_1', '/model/layers.13/Add', '/model/layers.13/Add_1', '/model/layers.14/self_attn/Add', '/model/layers.14/self_attn/Add', '/model/layers.14/self_attn/Add_1', '/model/layers.14/self_attn/Add_1', '/model/layers.14/Add', '/model/layers.14/Add_1', '/model/layers.15/self_attn/Add', '/model/layers.15/self_attn/Add', '/model/layers.15/self_attn/Add_1',
[modelopt][onnx] - INFO - Finding concat eliminated tensors
[modelopt][onnx] - DEBUG - Created temporary file for intermediate model: /tmp/tmpvmpcamb9.onnx
[modelopt][onnx] - INFO - Starting INT8 quantization with 'entropy' calibration
[modelopt][onnx] - INFO - Starting static quantization
[modelopt][onnx] - DEBUG - Quantization format: QDQ
[modelopt][onnx] - DEBUG - Activation type: QInt8
[modelopt][onnx] - DEBUG - Weight type: QInt8
[modelopt][onnx] - DEBUG - Calibration method: CalibrationMethod.Entropy
[modelopt][onnx] - DEBUG - Model size: 6176682096 bytes, using external data: True
[modelopt][onnx] - DEBUG - Calibration extra options: {'trt_extra_plugin_lib_paths': None, 'execution_providers': ['CPUExecutionProvider']}
[modelopt][onnx] - DEBUG - Creating calibrator
[modelopt][onnx] - DEBUG - Model size: 6176682096 bytes, using external data: True
[modelopt][onnx] - DEBUG - Creating inference session with Execution Provider configuration
[modelopt][onnx] - DEBUG - Execution providers: ['CPUExecutionProvider']
[modelopt][onnx] - DEBUG - Collecting calibration data
Collecting tensor data and making histogram ...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 593/593 [00:21<00:00, 28.19it/s]
Collecting tensor data and making histogram ...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 593/593 [00:20<00:00, 28.36it/s]
[modelopt][onnx] - DEBUG - Computing tensor ranges
Finding optimal threshold for each tensor using 'entropy' algorithm ...
Number of tensors : 593
Number of histogram bins : 128 (The number may increase depends on the data it collects)
Number of quantized bins : 128
[modelopt][onnx] - INFO - Starting post-processing of quantized model
[modelopt][onnx] - INFO - Deleting QDQ nodes from marked inputs to make certain operations fusible
[modelopt][onnx] - INFO - Converting float tensors to fp16
2025-10-22 07:01:03,611 - autocast - WARNING - graphsanitizer.py - Failed to convert model to opset 21: /home/task_176034957725121/conda-bld/onnx_1760349588539/work/onnx/version_converter/convert.h:89: assertInVersionRange: Assertion version >= version_range.first && version <= version_range.second failed: Warning: invalid version (must be between 1 and 24)
2025-10-22 07:01:03,611 - autocast - WARNING - graphsanitizer.py - Attempting to continue with the original opsets: [19]
[modelopt][onnx] - DEBUG - Model size: 6180275518 bytes, using external data: True
[modelopt][onnx] - DEBUG - Model size: 6180377492 bytes, using external data: True
[modelopt][onnx] - DEBUG - Model size: 3092713663 bytes, using external data: True
[modelopt][onnx] - DEBUG - Model size: 3093259088 bytes, using external data: True
[modelopt][onnx] - DEBUG - Model size: 3092711853 bytes, using external data: True
[modelopt][onnx] - DEBUG - Model size: 3092711906 bytes, using external data: True
[modelopt][onnx] - DEBUG - Model size: 3092711906 bytes, using external data: True

/home/jianlanye/miniconda3/envs/nv_modelopt_quantize/lib/python3.11/site-packages/modelopt/onnx/autocast/precisionconverter.py:991: RuntimeWarning: overflow encountered in cast
casted_data = original_data.astype(cast_dtype)
[modelopt][onnx] - DEBUG - Model size: 3092136553 bytes, using external data: True
[modelopt][onnx] - INFO - Starting INT8 to FP8 conversion
Traceback (most recent call last):
File "/data_nvme/jianlanye/tasks/l3slm_1.5b_fp8_quantization_1020/fp8_quantization/quantize_opt.py", line 79, in
moq.quantize(
File "/home/jianlanye/miniconda3/envs/nv_modelopt_quantize/lib/python3.11/site-packages/modelopt/onnx/quantization/quantize.py", line 474, in quantize
onnx_model = quantize_func(
^^^^^^^^^^^^^^
File "/home/jianlanye/miniconda3/envs/nv_modelopt_quantize/lib/python3.11/site-packages/modelopt/onnx/quantization/fp8.py", line 341, in quantize
onnx_model = int8_to_fp8(onnx_model)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jianlanye/miniconda3/envs/nv_modelopt_quantize/lib/python3.11/site-packages/modelopt/onnx/quantization/fp8.py", line 124, in int8_to_fp8
_convert(node)
File "/home/jianlanye/miniconda3/envs/nv_modelopt_quantize/lib/python3.11/site-packages/modelopt/onnx/quantization/fp8.py", line 89, in _convert
fp8_scale = _int8_scale_to_fp8_scale(scale, scale_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jianlanye/miniconda3/envs/nv_modelopt_quantize/lib/python3.11/site-packages/modelopt/onnx/quantization/fp8.py", line 71, in _int8_scale_to_fp8_scale
np_scale = onnx.numpy_helper.to_array(scale)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jianlanye/miniconda3/envs/nv_modelopt_quantize/lib/python3.11/site-packages/onnx/numpy_helper.py", line 414, in to_array
onnx.external_data_helper.load_external_data_for_tensor(tensor, base_dir)
File "/home/jianlanye/miniconda3/envs/nv_modelopt_quantize/lib/python3.11/site-packages/onnx/external_data_helper.py", line 53, in load_external_data_for_tensor
external_data_file_path = c_checker._resolve_external_data_location( # type: ignore[attr-defined]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnx.onnx_cpp2py_export.checker.ValidationError: Data of TensorProto ( tensor name: onnx::MatMul_8237_scale) should be stored in model_f931b22a.onnx_data, but it doesn't exist or is not accessible.
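
Not sure if it is relevant, but one mitigation I plan to try is re-saving the source model so that every initializer points at a single external-data file sitting next to the .onnx in the working directory, in case the stale reference comes from how the original external data is laid out. This is only a hedged sketch, not a known fix, and model_selfcontained.onnx is just an illustrative name:

import onnx

# Load the model (this reads the external weights into memory) and write it back
# with all tensors stored in one external-data file next to the new .onnx.
model = onnx.load("model.onnx")
onnx.save_model(
    model,
    "model_selfcontained.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="model_selfcontained.onnx_data",
)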

Expected behavior

The quantized fp8 ONNX model should be exported without errors.

System information

======================================================================

  • Container used (if applicable): ?
  • OS (e.g., Ubuntu 22.04, CentOS 7, Windows 10): Ubuntu 22.04.5 LTS
  • CPU architecture (x86_64, aarch64): x86_64
  • GPU name (e.g. H100, A100, L40S): NVIDIA H100 NVL
  • GPU memory size: 93.6 GB
  • Number of GPUs: 1
  • Library versions (if applicable):
    • Python: 3.11.14
    • ModelOpt version or commit hash: 0.37.0
    • CUDA: ?
    • PyTorch: 2.9.0+cu128
    • Transformers: ?
    • TensorRT-LLM: ?
    • ONNXRuntime: 1.23.0
    • TensorRT: ?
  • Any other details that may help: ?

======================================================================
Python script used to collect the system information above:
import platform
import re
import subprocess


def get_nvidia_gpu_info():
    try:
        nvidia_smi = (
            subprocess.check_output(
                "nvidia-smi --query-gpu=name,memory.total,count --format=csv,noheader,nounits",
                shell=True,
            )
            .decode("utf-8")
            .strip()
            .split("\n")
        )
        if len(nvidia_smi) > 0:
            gpu_name = nvidia_smi[0].split(",")[0].strip()
            gpu_memory = round(float(nvidia_smi[0].split(",")[1].strip()) / 1024, 1)
            gpu_count = len(nvidia_smi)
            return gpu_name, f"{gpu_memory} GB", gpu_count
    except Exception:
        pass
    # Fall through if nvidia-smi is unavailable or reports no GPUs
    return "?", "?", "?"


def get_cuda_version():
    try:
        nvcc_output = subprocess.check_output("nvcc --version", shell=True).decode("utf-8")
        match = re.search(r"release (\d+\.\d+)", nvcc_output)
        if match:
            return match.group(1)
    except Exception:
        pass
    # Fall through if nvcc is unavailable or the version string is not found
    return "?"


def get_package_version(package):
    try:
        return getattr(__import__(package), "__version__", "?")
    except Exception:
        return "?"


# Get system info
os_info = f"{platform.system()} {platform.release()}"
if platform.system() == "Linux":
    try:
        os_info = (
            subprocess.check_output("cat /etc/os-release | grep PRETTY_NAME | cut -d= -f2", shell=True)
            .decode("utf-8")
            .strip()
            .strip('"')
        )
    except Exception:
        pass
elif platform.system() == "Windows":
    print("Please add the `windows` label to the issue.")

cpu_arch = platform.machine()
gpu_name, gpu_memory, gpu_count = get_nvidia_gpu_info()
cuda_version = get_cuda_version()

# Print system information in the format required for the issue template
print("=" * 70)
print("- Container used (if applicable): " + "?")
print("- OS (e.g., Ubuntu 22.04, CentOS 7, Windows 10): " + os_info)
print("- CPU architecture (x86_64, aarch64): " + cpu_arch)
print("- GPU name (e.g. H100, A100, L40S): " + gpu_name)
print("- GPU memory size: " + gpu_memory)
print("- Number of GPUs: " + str(gpu_count))
print("- Library versions (if applicable):")
print("  - Python: " + platform.python_version())
print("  - ModelOpt version or commit hash: " + get_package_version("modelopt"))
print("  - CUDA: " + cuda_version)
print("  - PyTorch: " + get_package_version("torch"))
print("  - Transformers: " + get_package_version("transformers"))
print("  - TensorRT-LLM: " + get_package_version("tensorrt_llm"))
print("  - ONNXRuntime: " + get_package_version("onnxruntime"))
print("  - TensorRT: " + get_package_version("tensorrt"))
print("- Any other details that may help: " + "?")
print("=" * 70)

Labels: bug