4 changes: 4 additions & 0 deletions CHANGELOG.rst
@@ -6,6 +6,10 @@ Model Optimizer Changelog (Linux)

**Deprecations**

+- TRT-LLM's TRT backend in ``examples/llm_ptq`` and ``examples/vlm_ptq``.
+- ``--export_fmt`` flag in ``examples/llm_ptq`` is removed. By default we export to the unified Hugging Face checkpoint format.
+- ``examples/vlm_eval`` as it depends on the deprecated TRT-LLM's TRT backend.

**Bug Fixes**

**New Features**
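The deprecation entries above make the unified Hugging Face checkpoint the default export path in ``examples/llm_ptq`` (no ``--export_fmt`` needed). As a rough illustration only — the model name, calibration loop, and output directory are placeholders, and the ``mtq.quantize``/``mtq.FP8_DEFAULT_CFG``/``export_hf_checkpoint`` calls follow the public Model Optimizer API as I understand it, not code from this PR — the flow looks roughly like:

# Minimal sketch of the now-default unified HF export path (assumptions noted above).
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

def forward_loop(m):
    # Calibration pass; a single dummy batch here stands in for a real calibration set.
    inputs = tokenizer("Hello world", return_tensors="pt").to(m.device)
    m(**inputs)

# Quantize to FP8, one of the formats listed as supported for unified checkpoint export.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

# Export a unified Hugging Face checkpoint, matching the default path in hf_ptq.py.
export_hf_checkpoint(model, export_dir="./llama-3.1-8b-fp8-hf")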
2 changes: 1 addition & 1 deletion docker/Dockerfile
@@ -1,4 +1,4 @@
-FROM nvcr.io/nvidia/tensorrt-llm/release:1.0.0rc6
+FROM nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2

ARG PIP_EXTRA_INDEX_URL="https://pypi.nvidia.com"
ENV PIP_EXTRA_INDEX_URL=$PIP_EXTRA_INDEX_URL \
96 changes: 39 additions & 57 deletions examples/llm_ptq/hf_ptq.py
@@ -89,28 +89,20 @@ def auto_quantize(
    qformat_list = qformat.split(",")
    assert qformat_list, "No quantization formats provided"
    # Check if all provided quantization formats are supported
-    if args.export_fmt == "hf":
-        assert all(
-            qformat
-            in [
-                "fp8",
-                "int4_awq",
-                "nvfp4",
-                "nvfp4_awq",
-                "w4a8_awq",
-                "fp8_pb_wo",
-                "w4a8_mxfp4_fp8",
-                "nvfp4_mlp_only",
-            ]
-            for qformat in qformat_list
-        ), (
-            "One or more quantization formats provided are not supported for unified checkpoint export"
-        )
-    else:
-        assert all(
-            qformat in ["fp8", "int8_sq", "int4_awq", "w4a8_awq", "nvfp4", "nvfp4_awq"]
-            for qformat in qformat_list
-        ), "One or more quantization formats provided are not supported for tensorrt llm export"
+    assert all(
+        qformat
+        in [
+            "fp8",
+            "int4_awq",
+            "nvfp4",
+            "nvfp4_awq",
+            "w4a8_awq",
+            "fp8_pb_wo",
+            "w4a8_mxfp4_fp8",
+            "nvfp4_mlp_only",
+        ]
+        for qformat in qformat_list
+    ), "One or more quantization formats provided are not supported for unified checkpoint export"

    def loss_func(output, data):
        # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`

Reviewer comment (on the supported-qformat list):
"do you think we can pull this list of supported qformats as a variable and reuse in other places (in the auto quantize section)?"

Author (Collaborator) reply:
"ACK. This PR is pretty large. I hope we can move the improvements to a follow-up."
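Following up on that thread, here is a minimal sketch of hoisting the supported formats into one shared definition. The constant and helper names are hypothetical and not part of this PR.

# Hypothetical refactor: a single source of truth for the formats supported by unified HF export.
SUPPORTED_UNIFIED_HF_QFORMATS = frozenset(
    {
        "fp8",
        "int4_awq",
        "nvfp4",
        "nvfp4_awq",
        "w4a8_awq",
        "fp8_pb_wo",
        "w4a8_mxfp4_fp8",
        "nvfp4_mlp_only",
    }
)

def check_unified_hf_qformats(qformat: str) -> None:
    # Reusable check for both the auto-quantize path and the single-format path.
    unsupported = [q for q in qformat.split(",") if q not in SUPPORTED_UNIFIED_HF_QFORMATS]
    assert not unsupported, (
        f"Quantization formats {unsupported} are not supported for unified checkpoint export"
    )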
@@ -219,27 +211,21 @@ def main(args):
            "Quantization supports only one quantization format."
        )

-    # Check arguments for unified_hf export format and set to default if unsupported arguments are provided
-    if args.export_fmt == "hf":
-        assert args.sparsity_fmt == "dense", (
-            f"Sparsity format {args.sparsity_fmt} not supported by unified export api."
-        )
-
-        if not args.auto_quantize_bits:
-            assert (
-                args.qformat
-                in [
-                    "int4_awq",
-                    "fp8",
-                    "nvfp4",
-                    "nvfp4_awq",
-                    "w4a8_awq",
-                    "fp8_pb_wo",
-                    "w4a8_mxfp4_fp8",
-                    "nvfp4_mlp_only",
-                ]
-                or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
-            ), f"Quantization format {args.qformat} not supported for HF export path"
+    if not args.auto_quantize_bits:
+        assert (
+            args.qformat
+            in [
+                "int4_awq",
+                "fp8",
+                "nvfp4",
+                "nvfp4_awq",
+                "w4a8_awq",
+                "fp8_pb_wo",
+                "w4a8_mxfp4_fp8",
+                "nvfp4_mlp_only",
+            ]
+            or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
+        ), f"Quantization format {args.qformat} not supported for HF export path"

    # If low memory mode is enabled, we compress the model while loading the HF checkpoint.
    calibration_only = False

Review comment on lines +214 to 229 (⚠️ Potential issue):

The assertion is ineffective: the "or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES" clause is always True (argparse restricts the KV-cache format argument to KV_QUANT_CFG_CHOICES), so unsupported qformats currently pass through. Tighten the check:

-    if not args.auto_quantize_bits:
-        assert (
-            args.qformat
-            in [
-                "int4_awq",
-                "fp8",
-                "nvfp4",
-                "nvfp4_awq",
-                "w4a8_awq",
-                "fp8_pb_wo",
-                "w4a8_mxfp4_fp8",
-                "nvfp4_mlp_only",
-            ]
-            or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
-        ), f"Quantization format {args.qformat} not supported for HF export path"
+    if not args.auto_quantize_bits:
+        assert args.qformat in ALLOWED_UNIFIED_HF_QFORMATS, (
+            f"Quantization format {args.qformat} not supported for HF export path"
+        )

Note: If you intended to allow "KV-cache-only" quant, handle that as a separate branch instead of weakening this assert.

Committable suggestion skipped: line range outside the PR's diff.
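If a KV-cache-only path is in fact intended, one way to keep the weight-format check strict is to branch on that case explicitly. This is a sketch only: the "none" qformat sentinel and the ALLOWED_UNIFIED_HF_QFORMATS constant come from the suggestion above, not from code in this PR.

# Sketch: keep the qformat assert strict and treat KV-cache-only quantization as its own branch.
if not args.auto_quantize_bits:
    kv_cache_only = args.qformat == "none" and args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
    if not kv_cache_only:
        assert args.qformat in ALLOWED_UNIFIED_HF_QFORMATS, (
            f"Quantization format {args.qformat} not supported for HF export path"
        )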
@@ -253,9 +239,6 @@
            attn_implementation=args.attn_implementation,
        )
    else:
-        assert args.export_fmt == "hf", (
-            "Low memory mode is only supported for exporting HF checkpoint."
-        )
        assert args.qformat in QUANT_CFG_CHOICES, (
            f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
        )
@@ -600,7 +583,10 @@ def output_decode(generated_ids, input_shape):
        setattr(model.config, "architectures", full_model_config.architectures)

    start_time = time.time()
-    if args.export_fmt == "tensorrt_llm":
+    if model_type in ["t5", "bart", "whisper"] or args.sparsity_fmt != "dense":
+        # Still export TensorRT-LLM checkpoints for the models not supported by the
+        # TensorRT-LLM torch runtime.
+
        # Move meta tensor back to device before exporting.
        remove_hook_from_module(model, recurse=True)

@@ -621,13 +607,16 @@
            inference_tensor_parallel=args.inference_tensor_parallel,
            inference_pipeline_parallel=args.inference_pipeline_parallel,
        )
-    elif args.export_fmt == "hf":
+    else:
+        # Check arguments for unified_hf export format and set to default if unsupported arguments are provided
+        assert args.sparsity_fmt == "dense", (
+            f"Sparsity format {args.sparsity_fmt} not supported by unified export api."
+        )
+
        export_hf_checkpoint(
            full_model,
            export_dir=export_path,
        )
-    else:
-        raise NotImplementedError(f"{args.export_fmt} not supported")

    # Restore default padding and export the tokenizer as well.
    if tokenizer is not None:
@@ -707,13 +696,6 @@ def output_decode(generated_ids, input_shape):
        choices=KV_QUANT_CFG_CHOICES.keys(),
        help="Specify KV cache quantization format, default to fp8 if not provided",
    )
-    parser.add_argument(
-        "--export_fmt",
-        required=False,
-        default="tensorrt_llm",
-        choices=["tensorrt_llm", "hf"],
-        help=("Checkpoint export format"),
-    )
    parser.add_argument(
        "--trust_remote_code",
        help="Set trust_remote_code for Huggingface models and tokenizers",