4 changes: 4 additions & 0 deletions CHANGELOG.rst
@@ -7,6 +7,10 @@ Model Optimizer Changelog (Linux)
**Deprecations**
- Deprecated ``quantize_mode`` argument in ``examples/onnx_ptq/evaluate.py`` to support strong typing. Use ``engine_precision`` instead.

- TRT-LLM's TRT backend in ``examples/llm_ptq`` and ``examples/vlm_ptq``. Support for the ``build`` and ``benchmark`` tasks is removed and replaced with ``quant``. For performance evaluation, please use ``trtllm-bench`` directly.
- ``--export_fmt`` flag in ``examples/llm_ptq`` is removed. By default we export to the unified Hugging Face checkpoint format.
- ``examples/vlm_eval``, as it depends on the deprecated TRT backend of TRT-LLM.

**Bug Fixes**

**New Features**
6 changes: 3 additions & 3 deletions examples/llm_eval/README.md
@@ -93,7 +93,7 @@ If `trust_remote_code` needs to be true, please append the command with the `--t
### TensorRT-LLM

```sh
python lm_eval_tensorrt_llm.py --model trt-llm --model_args tokenizer=<HF model folder>,engine_dir=<TRT LLM engine dir> --tasks <comma separated tasks> --batch_size <engine batch size>
python lm_eval_tensorrt_llm.py --model trt-llm --model_args tokenizer=<HF model folder>,engine_dir=<Quantized checkpoint dir> --tasks <comma separated tasks> --batch_size <engine batch size>
```

## MMLU
@@ -140,7 +140,7 @@ python mmlu.py --model_name causal --model_path <HF model folder or model card>
### Evaluate the TensorRT-LLM engine

```bash
python mmlu.py --model_name causal --model_path <HF model folder or model card> --engine_dir <built TensorRT-LLM folder>
python mmlu.py --model_name causal --model_path <HF model folder or model card> --engine_dir <Quantized checkpoint dir>
```

## MT-Bench
@@ -163,7 +163,7 @@ bash run_fastchat.sh -h <HF model folder or model card> --quant_cfg MODELOPT_QUA
### Evaluate the TensorRT-LLM engine

```bash
bash run_fastchat.sh -h <HF model folder or model card> <built TensorRT-LLM folder>
bash run_fastchat.sh -h <HF model folder or model card> <Quantized checkpoint dir>
```

### Judging the responses
2 changes: 1 addition & 1 deletion examples/llm_ptq/README.md
@@ -203,7 +203,7 @@ scripts/huggingface_example.sh --type llama --model $HF_PATH --quant w4a8_awq,fp
The above example performs `AutoQuantize`, where the layers that are less sensitive to quantization accuracy are quantized with `w4a8_awq` (specified by `--quant w4a8_awq`) and the more sensitive layers are kept unquantized such that the effective bits are 4.8 (specified by `--auto_quantize_bits 4.8`).

The example scripts above also have an additional flag `--tasks`, which customizes the tasks actually run by the script. The allowed tasks are `build,mmlu,benchmark,lm_eval,livecodebench`, as specified in the script [parser](./scripts/parser.sh). A combination of tasks can be specified as a comma-separated list. Some tasks, such as mmlu, can take a long time to run. To run lm_eval tasks, please also specify the `--lm_eval_tasks` flag with comma-separated lm_eval tasks from [here](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/lm_eval/tasks).
The example scripts above also have an additional flag `--tasks`, which customizes the tasks actually run by the script. The allowed tasks are `quant,mmlu,lm_eval,livecodebench`, as specified in the script [parser](./scripts/parser.sh). A combination of tasks can be specified as a comma-separated list. Some tasks, such as mmlu, can take a long time to run. To run lm_eval tasks, please also specify the `--lm_eval_tasks` flag with comma-separated lm_eval tasks from [here](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/lm_eval/tasks).
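For illustration, a single invocation could chain quantization and evaluation. This is a minimal sketch based only on the flags documented above; the lm_eval task names (`arc_challenge,hellaswag`) are illustrative picks from the lm-evaluation-harness task list, not values prescribed by this example.

```sh
# Sketch only: quantize, then run MMLU and two lm_eval tasks in one pass.
# Flag values are illustrative; adjust the model path and formats as needed.
scripts/huggingface_example.sh \
    --type llama \
    --model $HF_PATH \
    --quant w4a8_awq,fp8 \
    --auto_quantize_bits 4.8 \
    --tasks quant,mmlu,lm_eval \
    --lm_eval_tasks arc_challenge,hellaswag
```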

> *If a GPU out-of-memory error is reported when running the scripts, please try editing the scripts and reducing the max batch size to save GPU memory.*

107 changes: 45 additions & 62 deletions examples/llm_ptq/hf_ptq.py
@@ -89,28 +89,20 @@ def auto_quantize(
qformat_list = qformat.split(",")
assert qformat_list, "No quantization formats provided"
# Check if all provided quantization formats are supported
if args.export_fmt == "hf":
assert all(
qformat
in [
"fp8",
"int4_awq",
"nvfp4",
"nvfp4_awq",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
]
for qformat in qformat_list
), (
"One or more quantization formats provided are not supported for unified checkpoint export"
)
else:
assert all(
qformat in ["fp8", "int8_sq", "int4_awq", "w4a8_awq", "nvfp4", "nvfp4_awq"]
for qformat in qformat_list
), "One or more quantization formats provided are not supported for tensorrt llm export"
assert all(
qformat
in [

Reviewer: do you think we can pull this list of supported qformats as a variable and reuse in other places (in the auto quantize section)?

Collaborator Author: ACK. This PR is pretty large. I hope we can move the improvements to a follow-up.

"fp8",
"int4_awq",
"nvfp4",
"nvfp4_awq",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
]
for qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"

def loss_func(output, data):
# For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
@@ -219,27 +211,21 @@ def main(args):
"Quantization supports only one quantization format."
)

# Check arguments for unified_hf export format and set to default if unsupported arguments are provided
if args.export_fmt == "hf":
assert args.sparsity_fmt == "dense", (
f"Sparsity format {args.sparsity_fmt} not supported by unified export api."
)

if not args.auto_quantize_bits:
assert (
args.qformat
in [
"int4_awq",
"fp8",
"nvfp4",
"nvfp4_awq",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
]
or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
), f"Quantization format {args.qformat} not supported for HF export path"
if not args.auto_quantize_bits:
assert (
args.qformat
in [
"int4_awq",
"fp8",
"nvfp4",
"nvfp4_awq",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
]
or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
), f"Quantization format {args.qformat} not supported for HF export path"

Comment on lines +214 to 229
⚠️ Potential issue

Assertion is ineffective; `or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES` is always True.

This currently lets unsupported qformats through. Tighten the check.

-    if not args.auto_quantize_bits:
-        assert (
-            args.qformat
-            in [
-                "int4_awq",
-                "fp8",
-                "nvfp4",
-                "nvfp4_awq",
-                "w4a8_awq",
-                "fp8_pb_wo",
-                "w4a8_mxfp4_fp8",
-                "nvfp4_mlp_only",
-            ]
-            or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
-        ), f"Quantization format {args.qformat} not supported for HF export path"
+    if not args.auto_quantize_bits:
+        assert args.qformat in ALLOWED_UNIFIED_HF_QFORMATS, (
+            f"Quantization format {args.qformat} not supported for HF export path"
+        )

Note: If you intended to allow “KV-cache-only” quant, handle that as a separate branch instead of weakening this assert.

Committable suggestion skipped: line range outside the PR's diff.
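If KV-cache-only quantization is indeed the intent, the note above could be realized as a separate branch rather than an `or` inside the assert. A hedged sketch, reusing the hypothetical `SUPPORTED_UNIFIED_HF_QFORMATS` constant from the earlier sketch and assuming the script's existing `args.qformat`, `args.kv_cache_qformat`, and `KV_QUANT_CFG_CHOICES`; the `"none"` sentinel is an assumption, not necessarily this script's actual behavior:

```python
# Hypothetical sketch, not code from this PR.
if not args.auto_quantize_bits:
    if args.qformat == "none":  # assumed sentinel for "weights not quantized"
        # KV-cache-only path: validate only the KV-cache format.
        assert args.kv_cache_qformat in KV_QUANT_CFG_CHOICES, (
            f"KV cache quantization format {args.kv_cache_qformat} is not supported"
        )
    else:
        assert args.qformat in SUPPORTED_UNIFIED_HF_QFORMATS, (
            f"Quantization format {args.qformat} not supported for HF export path"
        )
```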

# If low memory mode is enabled, we compress the model while loading the HF checkpoint.
calibration_only = False
@@ -253,9 +239,6 @@
attn_implementation=args.attn_implementation,
)
else:
assert args.export_fmt == "hf", (
"Low memory mode is only supported for exporting HF checkpoint."
)
assert args.qformat in QUANT_CFG_CHOICES, (
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
)
@@ -600,34 +583,31 @@ def output_decode(generated_ids, input_shape):
setattr(model.config, "architectures", full_model_config.architectures)

start_time = time.time()
if args.export_fmt == "tensorrt_llm":
if model_type in ["t5", "bart", "whisper"] or args.sparsity_fmt != "dense":
warnings.warn(
"Still exporting TensorRT-LLM checkpoints for models not supported by the TensorRT-LLM torch runtime."
)

# Move meta tensor back to device before exporting.
remove_hook_from_module(model, recurse=True)

dtype = None
if "w4a8_awq" in args.qformat:
# TensorRT-LLM w4a8 only support fp16 as the dtype.
dtype = torch.float16

# For Gemma2-27B, TRT-LLM only works with bfloat16 as the dtype.
if model_type == "gemma2":
dtype = torch.bfloat16

export_tensorrt_llm_checkpoint(
model,
model_type,
dtype=dtype,
export_dir=export_path,
inference_tensor_parallel=args.inference_tensor_parallel,
inference_pipeline_parallel=args.inference_pipeline_parallel,
)
elif args.export_fmt == "hf":
else:
# Check arguments for unified_hf export format and set to default if unsupported arguments are provided
assert args.sparsity_fmt == "dense", (
f"Sparsity format {args.sparsity_fmt} not supported by unified export api."
)

export_hf_checkpoint(
full_model,
export_dir=export_path,
)
else:
raise NotImplementedError(f"{args.export_fmt} not supported")

# Restore default padding and export the tokenizer as well.
if tokenizer is not None:
@@ -710,9 +690,9 @@ def output_decode(generated_ids, input_shape):
parser.add_argument(
"--export_fmt",
required=False,
default="tensorrt_llm",
default="hf",
choices=["tensorrt_llm", "hf"],
help=("Checkpoint export format"),
help="Deprecated. Please avoid using this argument.",
)
parser.add_argument(
"--trust_remote_code",
Expand Down Expand Up @@ -767,6 +747,9 @@ def output_decode(generated_ids, input_shape):

args = parser.parse_args()

if args.export_fmt != "hf":
warnings.warn("Deprecated. --export_fmt will be ignored.")

args.dataset = args.dataset.split(",") if args.dataset else None
args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")]
main(args)