
[Bug]: modelopt max_calibrate AssertionError when running quantize.py (fp8/nvfp4) on Blackwell RTX PRO 6000 with NGC 1.2.0rc3 #9402

@coga-ash

Description

System Info

  • CPU architecture: x86_64
  • CPU: Intel Xeon Silver 4514Y (16 cores / 32 threads)
  • Host memory: large enough to host multiple 70B models in parallel (exact size can be provided if needed)
  • GPU properties:
    • GPU name: 4 × NVIDIA RTX PRO 6000 Blackwell
    • GPU memory: 96 GB VRAM per GPU
  • TensorRT-LLM:
    • Version (from logs): 1.2.0rc3 ([TensorRT-LLM] TensorRT LLM version: 1.2.0rc3)
    • Container image: nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc3
  • Libraries inside the NGC container:
    • PyTorch: 2.9.0a0+145a3a7 (printed at container startup)
    • ModelOpt: version bundled in the 1.2.0rc3 container (used via tensorrt_llm.quantization.quantize_by_modelopt)
  • CUDA / driver:
    • Host NVIDIA driver: 580.95.05
    • Host CUDA version: 13.0 (nvidia-smi shows “CUDA Version: 13.0”)
  • OS:
    • Host OS: Ubuntu 24.04 (Noble)
    • Running the NGC container via Docker on the host
  • GPUs are all visible and healthy from the host:
    • nvidia-smi shows 4× RTX PRO 6000 Blackwell with ~98 GB total memory each, driver 580.95.05

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Steps to reproduce the behavior:

  1. Pull the official NGC TensorRT-LLM container:

     docker pull nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc3

  2. On the host, download and store the HuggingFace models under ~/hf_model. In my case, I have:

     ~/hf_model/A.X-4.0-Light-7B (Qwen2-based model)
     ~/hf_model/Midm-2.0-Base-Instruct-11B (LLaMA-based model)
     ~/trt_model/trt_ckpt and ~/trt_model/trt_engine as output folders

  3. Run the NGC container with all GPUs and mount the model folders:

     docker run --rm -it --gpus all --ipc=host \
       --ulimit memlock=-1 --ulimit stack=67108864 \
       -v $HOME/hf_model:/workspace/hf_model \
       -v $HOME/trt_model:/workspace/trt_model \
       nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc3

  4. Inside the container, go to the quantization example:

     cd /app/tensorrt_llm/examples/quantization

  5. Run quantize.py using either the fp8 or nvfp4 qformat. Both of the following commands fail with the same type of assertion error:

(a) fp8 quantization of A.X-4.0-Light-7B (Qwen2-based)

python3 quantize.py \
  --model_dir /workspace/hf_model/A.X-4.0-Light-7B \
  --qformat fp8 \
  --kv_cache_dtype fp8 \
  --output_dir /workspace/trt_model/trt_ckpt/A.X-4.0-Light-7B-fp8

(b) nvfp4 quantization of the same model

python3 quantize.py \
  --model_dir /workspace/hf_model/A.X-4.0-Light-7B \
  --qformat nvfp4 \
  --kv_cache_dtype fp8 \
  --output_dir /workspace/trt_model/trt_ckpt/A.X-4.0-Light-7B-nvfp4

(c) Similar failure with a LLaMA-based model (Midm-2.0-Base-Instruct-11B)

python3 quantize.py \
  --model_dir /workspace/hf_model/Midm-2.0-Base-Instruct-11B \
  --qformat nvfp4 \
  --kv_cache_dtype fp8 \
  --tp_size 4 \
  --output_dir /workspace/trt_model/trt_ckpt/Midm-2.0-Base-Instruct-11B-nvfp4

The error occurs regardless of whether I specify --tp_size 4 or use the default, so it does not seem to depend on tensor parallelism.

Error message

[TensorRT-LLM] TensorRT LLM version: 1.2.0rc3
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|█████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.35it/s]
Registered <class 'transformers.models.qwen2.modeling_qwen2.Qwen2Attention'> to _QuantAttention for KV Cache quantization
Inserted 675 quantizers
Traceback (most recent call last):
  File "/app/tensorrt_llm/examples/quantization/quantize.py", line 160, in <module>
    quantize_and_export(
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/quantization/quantize_by_modelopt.py", line 788, in quantize_and_export
    model = quantize_model(model, quant_cfg, calib_dataloader, batch_size,
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/quantization/quantize_by_modelopt.py", line 575, in quantize_model
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
  File "/usr/local/lib/python3.12/dist-packages/modelopt/torch/quantization/model_quant.py", line 226, in quantize
    return calibrate(model, config["algorithm"], forward_loop=forward_loop)
  File "/usr/local/lib/python3.12/dist-packages/modelopt/torch/quantization/model_quant.py", line 102, in calibrate
    apply_mode(
  File "/usr/local/lib/python3.12/dist-packages/modelopt/torch/opt/conversion.py", line 418, in apply_mode
    model, metadata = get_mode(m).convert(model, config, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/modelopt/torch/quantization/mode.py", line 293, in wrapped_func
    return wrapped_calib_func(model, config, forward_loop, func=self.__class__._calib_func)
  File "/usr/local/lib/python3.12/dist-packages/modelopt/torch/quantization/mode.py", line 219, in wrapped_calib_func
    func(model, forward_loop=forward_loop, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/modelopt/torch/quantization/model_calib.py", line 76, in max_calibrate
    forward_loop(model)
  File "/usr/local/lib/python3.12/dist-packages/modelopt/torch/quantization/model_quant.py", line 95, in forward_loop
    return original_forward_loop()
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/quantization/quantize_by_modelopt.py", line 487, in calibrate_loop
    model(**data)
  ...
  File "/usr/local/lib/python3.12/dist-packages/modelopt/torch/quantization/nn/modules/tensor_quantizer.py", line 953, in forward
    self.collect(inputs)
  File "/usr/local/lib/python3.12/dist-packages/modelopt/torch/quantization/nn/modules/tensor_quantizer.py", line 1137, in collect
    self._calibrator.collect(inputs)
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/modelopt/torch/quantization/calib/max.py", line 69, in collect
    assert torch.all(local_amax >= 0), (
AssertionError: detected negative values after abs, could be torch or cuda bug
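The failing invariant can be checked in isolation. The following is a minimal sketch (not the actual ModelOpt code) that mirrors the check in modelopt/torch/quantization/calib/max.py: after abs(), the reduced amax must be non-negative and finite. Running it on the affected GPU vs. CPU may help show whether torch.abs()/amax itself misbehaves on this hardware/stack:

```python
# Standalone sanity check mirroring the assertion that fails in ModelOpt's
# max calibrator. This is a sketch for isolating the failure; function and
# tensor shapes are illustrative, not taken from the real calibration path.
import torch


def check_amax(x: torch.Tensor) -> tuple[bool, bool]:
    """Return (non_negative, finite) for the abs-reduced amax of x."""
    local_amax = x.abs().amax()
    non_negative = bool(torch.all(local_amax >= 0))
    finite = bool(torch.isfinite(local_amax))
    return non_negative, finite


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # A healthy activation-like tensor should pass both checks.
    healthy = torch.randn(1024, 1024, dtype=torch.float16, device=device)
    print("healthy tensor:", check_amax(healthy))
    # A tensor that already contains inf reproduces the second assertion
    # ("detected inf values in amax"): amax is still >= 0 but not finite.
    bad = torch.randn(8, 8, device=device)
    bad[0, 0] = float("inf")
    print("tensor with inf:", check_amax(bad))
```

If the healthy case fails on the GPU but passes on CPU, that would point at a torch/CUDA numerical issue on Blackwell rather than anything model-specific.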

Expected behavior

I expected quantize.py to successfully complete post-training quantization (either fp8 or nvfp4) for these HuggingFace models and export a valid TensorRT-LLM checkpoint under trt_ckpt, or at least fail with a more conventional error (e.g., GPU OOM).

In particular, I expected that the ModelOpt max calibration step would not hit an internal assertion like “detected negative values after abs” or “detected inf values in amax”, since those suggest an internal numerical issue rather than a user configuration error.

Actual behavior

Instead of completing, quantize.py aborts during the ModelOpt max calibration step with an AssertionError:

  • AssertionError: detected negative values after abs, could be torch or cuda bug
  • or AssertionError: detected inf values in amax. inf in original tensor: True

This happens:

  • With qformat=fp8 and qformat=nvfp4
  • With Qwen2-based (A.X-4.0-Light-7B) and LLaMA-based (Midm-2.0-Base-Instruct-11B, A.X-4.0-70B) models
  • Whether I use --tp_size 4 or leave the default tensor parallel size (so it does not seem to depend on multi-GPU configuration)
  • Inside a clean NGC TensorRT-LLM container (release:1.2.0rc3) on a system with 4× RTX PRO 6000 Blackwell GPUs

Because of the assertion, no trt_ckpt output is produced for these qformats.

Additional notes

  • The same models can be quantized to TensorRT-LLM format on this machine using the separate TRT Model Optimizer (hf_ptq.py, exporting with --export_fmt tensorrt_llm) outside of this NGC container. In that flow, I only see expected issues like GPU OOM if I push the limits too far, but I do not see the ModelOpt max_calibrate assertion.
  • This makes me suspect that the problem may be specific to:
    • the combination of TensorRT-LLM 1.2.0rc3 + bundled ModelOpt version inside the NGC container, and/or
    • running on Blackwell-generation GPUs (RTX PRO 6000 Blackwell, CUDA 13.0).
  • The error appears deterministic: it reproduces every time with the commands above on a clean container.
  • I am happy to:
    • run additional experiments with other NGC tags (e.g. 1.2.0rc2) if that helps narrow down a regression,
    • rerun with export TLLM_DEBUG_MODE=1 and attach additional logs,
    • or test potential patches/workarounds on this Blackwell system.
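As one of those additional experiments, it may help to locate which layer first emits non-finite activations during the calibration forward pass. Below is a hedged sketch using only plain PyTorch forward hooks (no ModelOpt internals); the toy nn.Sequential stands in for the real HF model, and all names are illustrative:

```python
# Sketch: register forward hooks that record which modules produce
# inf/NaN outputs during a forward pass. Uses only standard torch APIs;
# the toy model below is a placeholder for the real HF model.
import torch
import torch.nn as nn


def install_nonfinite_hooks(model: nn.Module) -> list[str]:
    """Return a list that accumulates names of modules emitting non-finite outputs."""
    offenders: list[str] = []

    def make_hook(name: str):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                offenders.append(name)
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            module.register_forward_hook(make_hook(name))
    return offenders


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    offenders = install_nonfinite_hooks(model)
    # Force an overflow in the first linear layer to demonstrate detection.
    with torch.no_grad():
        model[0].weight.fill_(float("inf"))
    model(torch.randn(2, 8))
    print("first non-finite output at:", offenders[0] if offenders else None)
```

Running something like this on the real model just before mtq.quantize would show whether the inf values originate in a specific layer's activations or only appear inside the calibrator.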

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and checked the documentation and examples for answers to frequently asked questions.

Metadata

Labels

  • Model optimization<NV> (Model-specific performance optimizations and tuning)
  • bug (Something isn't working)
