Merged
37 commits
04a3a99
e2e example for qlora ddp export
sugunav14 Sep 22, 2025
33c90d1
updated readme
sugunav14 Sep 22, 2025
90dc505
minor refactor
sugunav14 Sep 22, 2025
2494386
minor update
sugunav14 Sep 22, 2025
3558325
updated unit test
sugunav14 Sep 22, 2025
0338c84
Minor bug fix
sugunav14 Sep 23, 2025
0c04055
Update trainer to save base model weights and config.json
sugunav14 Sep 25, 2025
2fec64d
export for fp8 lora base model
sugunav14 Sep 26, 2025
0ad6e5f
added support for nvfp4 export
sugunav14 Sep 26, 2025
beff531
e2e checkpoint tested for nvfp4 and fp8
sugunav14 Sep 29, 2025
7873a8e
cleanup
sugunav14 Sep 29, 2025
33b9a03
refactored
sugunav14 Sep 30, 2025
e20b853
minor update
sugunav14 Sep 30, 2025
fbbc96b
added requantize/resmooth for qlora export
sugunav14 Sep 30, 2025
a112fee
removed stray print
sugunav14 Sep 30, 2025
88d907e
update readme and documentation
sugunav14 Sep 30, 2025
987490e
minor fix
sugunav14 Sep 30, 2025
268c94c
added logging statements
sugunav14 Sep 30, 2025
ba7ffab
minor
sugunav14 Sep 30, 2025
6cf109f
added TODO
sugunav14 Sep 30, 2025
6697d4d
Refactor to include QAT/QAD export too
sugunav14 Oct 6, 2025
8f18a1b
Updated README
sugunav14 Oct 6, 2025
1490517
updated condition in get_quant_config
sugunav14 Oct 7, 2025
522cd98
added check for frozen base model
sugunav14 Oct 7, 2025
f7298e7
updated check
sugunav14 Oct 7, 2025
1fcb56d
Update examples/llm_qat/README.md
sugunav14 Oct 9, 2025
ac8c61c
Update examples/llm_qat/README.md
sugunav14 Oct 9, 2025
33d79e0
Update examples/llm_qat/README.md
sugunav14 Oct 13, 2025
f99e2fb
updated function name
sugunav14 Oct 20, 2025
14dad2b
changelog update
sugunav14 Oct 20, 2025
7cca3b7
updated check
sugunav14 Nov 11, 2025
676c8d8
formatting fix
sugunav14 Nov 11, 2025
a5cb105
minor fix
sugunav14 Nov 11, 2025
dbf15f1
update
sugunav14 Nov 12, 2025
fd40058
update
sugunav14 Nov 12, 2025
9c3d6bf
update
sugunav14 Nov 12, 2025
e7e5fdc
minor fix
sugunav14 Nov 12, 2025
2 changes: 2 additions & 0 deletions CHANGELOG.rst
@@ -35,10 +35,12 @@ Model Optimizer Changelog (Linux)
- Add support for ``torch.compile`` and benchmarking in ``examples/diffusers/quantization/diffusion_trt.py``.
- Enabled native Modelopt quantization support for FP8 and NVFP4 formats in SGLang. See `SGLang quantization documentation <https://github.com/sgl-project/sglang/blob/main/docs/advanced_features/quantization.md#using-nvidia-modelopt>`_ for more details.
- Added modelopt quantized checkpoints in vLLM/SGLang CI/CD pipelines (PRs are under review).
- Add support for exporting QLoRA checkpoints fine-tuned using ModelOpt.

**Documentation**

- Add general guidelines for Minitron pruning and distillation. See `examples/pruning/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/pruning#pruning-guidelines>`_ for more details.
- Added an example for exporting a QLoRA checkpoint for vLLM deployment. Refer to `examples/llm_qat/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/79ef31bc7269ba4da0cfab446da5b64509cbfcef/examples/llm_qat/README.md#qlora-deployment>`_ for more details.

0.37 (2025-10-08)
^^^^^^^^^^^^^^^^^
27 changes: 21 additions & 6 deletions examples/llm_qat/README.md
@@ -301,14 +301,12 @@ See more details on running LLM evaluation benchmarks [here](../llm_eval/README.md).

## Deployment

The final model after QAT is similar in architecture to that of PTQ model. QAT model simply have updated weights as compared to the PTQ model. It can be deployed to TensorRT-LLM (TRTLLM) or to TensorRT just like a regular **ModelOpt** PTQ model if the quantization format is supported for deployment.
The final model after QAT/QAD is similar in architecture to the PTQ model; it simply has updated weights. It can be deployed to TensorRT-LLM (TRTLLM), TensorRT, vLLM, or SGLang just like a regular **ModelOpt** PTQ model if the quantization format is supported for deployment.

To run QAT model with TRTLLM, run:
To export a TRTLLM/vLLM/SGLang-compatible checkpoint for the model after QAT (or QAD), run:

```sh
cd ../llm_ptq

./scripts/huggingface_example.sh --model ../llm_qat/llama3-qat --quant w4a8_awq
python export.py --pyt_ckpt_path llama3-qat --export_path llama3-qat-deploy
```

Note: The QAT checkpoint for `w4a8_awq` config can be created by using `--quant_cfg W4A8_AWQ_BETA_CFG` in [QAT example](#end-to-end-qat-example).
@@ -345,8 +343,25 @@ To perform QLoRA training, run:
--lora True
```

> **_NOTE:_** QLoRA is currently an experimental feature designed to reduce the memory footprint during training. Deployment functionality is not yet available.
## QLoRA Deployment

After QLoRA training, the final checkpoint can be exported for deployment with vLLM using the following command:

```sh
python export.py \
--pyt_ckpt_path llama3-fp4-qlora \
    --export_path llama3-fp4-qlora-hf
```

To deploy with vLLM, run the following command. For more details about QLoRA deployment with vLLM, refer to the documentation [here](https://docs.vllm.ai/en/latest/features/lora.html).

```sh
vllm serve llama3-fp4-qlora-hf/base_model --enable-lora --lora-modules adapter=llama3-fp4-qlora-hf --port 8000 --tokenizer llama3-fp4-qlora-hf
```

> _Note: Exporting QLoRA models trained with FSDP2 is not currently supported._

## Pre-Quantized Checkpoints

- Ready-to-deploy checkpoints \[[🤗 Hugging Face - Nvidia TensorRT Model Optimizer Collection](https://huggingface.co/collections/nvidia/inference-optimized-checkpoints-with-model-optimizer)\]
135 changes: 135 additions & 0 deletions examples/llm_qat/export.py
@@ -0,0 +1,135 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import warnings
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.opt as mto
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
from modelopt.torch.opt.conversion import restore_from_modelopt_state
from modelopt.torch.quantization.utils import set_quantizer_state_dict
from modelopt.torch.utils import print_rank_0

RAND_SEED = 1234

# Enable automatic save/load of the modelopt state during huggingface checkpointing
mto.enable_huggingface_checkpointing()


def get_model(
ckpt_path: str,
device="cuda",
):
"""
Loads a QLoRA model that has been trained using modelopt trainer.
"""
# TODO: Add support for merging adapters in BF16 and merging adapters with quantization for deployment
device_map = "auto"
if device == "cpu":
device_map = "cpu"

# Load model
model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)

    # Restore modelopt state for LoRA models. For QAT/QAD models, the from_pretrained call handles this
if hasattr(model, "peft_config"):
modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
restore_from_modelopt_state(model, modelopt_state)
print_rank_0("Restored modelopt state")

# Restore modelopt quantizer state dict
modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
if modelopt_weights is not None:
set_quantizer_state_dict(model, modelopt_weights)
print_rank_0("Restored modelopt quantizer state dict")
Comment on lines +53 to +62
⚠️ Potential issue | 🟠 Major

Verify the state restoration order matches the trainer's pattern.

The sequence here differs from the trainer's _restore_modelopt_state_with_weights method (transformers_trainer.py lines 184-190):

Trainer pattern:

  1. Load state dict
  2. Pop modelopt_state_weights
  3. Call restore_from_modelopt_state
  4. Set quantizer state dict

Current export.py pattern:

  1. Load state dict
  2. Call restore_from_modelopt_state (line 55)
  3. Pop modelopt_state_weights (line 59)
  4. Set quantizer state dict

The pop operation at line 59 occurs after restore_from_modelopt_state, which may cause the method to receive and process the weights that should be handled separately. Align the order with the trainer's proven pattern.

Apply this diff:

     # Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
     if hasattr(model, "peft_config"):
         modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
+        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
         restore_from_modelopt_state(model, modelopt_state)
         print_rank_0("Restored modelopt state")
 
         # Restore modelopt quantizer state dict
-        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
         if modelopt_weights is not None:
             set_quantizer_state_dict(model, modelopt_weights)
             print_rank_0("Restored modelopt quantizer state dict")
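The ordering issue flagged here can be shown with plain dictionaries: popping `modelopt_state_weights` before the restore call guarantees the restore step never sees the quantizer weights. A minimal sketch with stand-in dicts (not the real modelopt state):

```python
# Stand-in for torch.load(f"{ckpt_path}/modelopt_state_train.pth", ...)
modelopt_state = {
    "modelopt_state_dict": {"mode": "quantize"},
    "modelopt_state_weights": {"q_proj.input_quantizer._amax": 1.0},
}

def restore_from_state(state):
    # A restore that still contains the weights key would treat it as metadata.
    assert "modelopt_state_weights" not in state, "pop the weights first"
    return state["modelopt_state_dict"]

# Trainer pattern: pop first, restore second, apply weights last.
modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
restored = restore_from_state(modelopt_state)
if modelopt_weights is not None:
    pass  # set_quantizer_state_dict(model, modelopt_weights) in the real code
```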


return model


def main(args):
# Load model
model = get_model(args.pyt_ckpt_path, args.device)
tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
is_qlora = hasattr(model, "peft_config")

# Export HF checkpoint
export_dir = Path(args.export_path)
export_dir.mkdir(parents=True, exist_ok=True)
if is_qlora:
base_model_dir = export_dir / "base_model"
base_model_dir.mkdir(parents=True, exist_ok=True)
else:
base_model_dir = export_dir

try:
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=is_qlora)

with open(f"{base_model_dir}/hf_quant_config.json", "w") as file:
json.dump(hf_quant_config, file, indent=4)

hf_quant_config = convert_hf_quant_config_format(hf_quant_config)

# Save model
if is_qlora:
model.base_model.save_pretrained(f"{base_model_dir}", state_dict=post_state_dict)
model.save_pretrained(export_dir)
else:
model.save_pretrained(export_dir, state_dict=post_state_dict)

config_path = f"{base_model_dir}/config.json"

config_data = model.config.to_dict()

config_data["quantization_config"] = hf_quant_config

with open(config_path, "w") as file:
json.dump(config_data, file, indent=4)
Comment on lines +97 to +104

⚠️ Potential issue | 🟠 Major

Use base model config for QLoRA exports.

For QLoRA models, model.config returns the PEFT model's config, but we're writing to base_model_dir/config.json (the base model's config file). This overwrites the base model config that was saved by model.base_model.save_pretrained at line 92.

Apply this diff:

         config_path = f"{base_model_dir}/config.json"
 
-        config_data = model.config.to_dict()
+        if is_qlora:
+            config_data = model.base_model.config.to_dict()
+        else:
+            config_data = model.config.to_dict()
 
         config_data["quantization_config"] = hf_quant_config
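The suggested fix amounts to augmenting the base model's config rather than overwriting it with the PEFT wrapper's. A small sketch with hypothetical stand-in dicts:

```python
import json

# Stand-ins: the wrapper and base configs differ for a PEFT/QLoRA model.
base_config = {"model_type": "llama", "hidden_size": 4096}
wrapper_config = {"peft_type": "LORA", "r": 16}
hf_quant_config = {"quant_algo": "NVFP4"}

is_qlora = True
# For QLoRA, start from the base model config so base_model/config.json
# is augmented with quantization info, not replaced by the wrapper config.
config_data = dict(base_config if is_qlora else wrapper_config)
config_data["quantization_config"] = hf_quant_config
print(json.dumps(config_data, indent=4))
```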


# Save tokenizer
tokenizer.save_pretrained(export_dir)

except Exception as e:
warnings.warn(
"Cannot export model to the model_config. The modelopt-optimized model state_dict"
" can be saved with torch.save for further inspection."
)
raise e


if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--pyt_ckpt_path",
help="Specify where the PyTorch checkpoint path is",
required=True,
)

parser.add_argument("--device", default="cuda")

parser.add_argument(
"--export_path",
default="exported_model",
help="Path to save the exported model",
)

args = parser.parse_args()

main(args)
43 changes: 35 additions & 8 deletions modelopt/torch/export/quant_utils.py
@@ -827,13 +827,19 @@ def from_quantized_weight(
raise NotImplementedError(f"quantization format {quantization} not supported")


def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str | None) -> dict:
def postprocess_state_dict(
state_dict: dict,
maxbound: float,
quantization: str | None,
is_modelopt_qlora: bool = False,
) -> dict:
"""Filters out keys related to weight quantizers and updates KV cache related keys.

Args:
state_dict: The full model state_dict.
maxbound: The maximum bound value for the output quantizer.
quantization: The KV cache quantization format.
is_modelopt_qlora: Whether the model is a modelopt-trained QLoRA model.

Returns:
The filtered state_dict without unnecessary keys like '_amax' and non KV cache output quantizers.
@@ -845,6 +851,18 @@ def postprocess_state_dict(
"v_bmm_quantizer._bias_value": "v_proj.v_bias",
"input_quantizer._pre_quant_scale": "pre_quant_scale",
}
skip_keys = ["output_quantizer", "_amax", "_bias_value", "input_quantizer._pre_quant_scale"]

# For modelopt-trained LoRA models, we need to remove the base_layer prefix from the keys for deployment
if is_modelopt_qlora:
replacements.update(
{
"base_layer.weight": "weight",
"base_layer.input_scale": "input_scale",
"base_layer.weight_scale": "weight_scale",
}
)
skip_keys.append("base_layer")

post_state_dict = {}

@@ -855,12 +873,7 @@ def postprocess_state_dict(
continue

# Skip keys not related to quantizers
if (
"output_quantizer" not in key
and "_amax" not in key
and "_bias_value" not in key
and "input_quantizer._pre_quant_scale" not in key
):
if all(skip_key not in key for skip_key in skip_keys):
post_state_dict[key] = value
continue

@@ -911,6 +924,11 @@ def postprocess_state_dict(
):
keys_to_delete.append(key)

# remove LoRA adapters from state dict
if is_modelopt_qlora:
for key in post_state_dict:
if "lora" in key and key not in keys_to_delete:
keys_to_delete.append(key)
# Check for tied weights and remove duplicates
seen_tensors = {}

@@ -1029,6 +1047,7 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False

def get_quant_config(
model: nn.Module,
is_modelopt_qlora: bool = False,
) -> dict[str, Any]:
"""Generate quantization config for a model.

@@ -1037,6 +1056,7 @@ def get_quant_config(

Args:
model: The PyTorch model to make config for.
is_modelopt_qlora: Whether the model is a modelopt-trained QLoRA model.

Returns:
Dictionary containing the quantization configuration
@@ -1073,7 +1093,14 @@
or hasattr(module, quantizer_attr_names(weight_name).input_quantizer)
for weight_name in weight_names
)
if has_quantizers:

        # Skip LoRA modules and adapters.
        # ModelOpt does not currently quantize these layers in the QLoRA path.
skip_layer = is_modelopt_qlora and (
hasattr(module, "base_layer") or "lora_A" in name or "lora_B" in name
)

if has_quantizers and not skip_layer:
quantization_format = get_quantization_format(module)

# For MoE expert modules, we need to extract block size from the correct weight quantizer
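The QLoRA branch added to `postprocess_state_dict` amounts to a rename-and-filter pass over the state dict: strip the `base_layer.` prefix from quantized weights and drop the adapter tensors. A simplified sketch of just that pass (illustrative keys, not the full postprocess logic):

```python
# Suffix renames from the diff above; adapters are filtered out entirely.
replacements = {
    "base_layer.weight": "weight",
    "base_layer.input_scale": "input_scale",
    "base_layer.weight_scale": "weight_scale",
}

# Hypothetical QLoRA state dict entries for one projection layer.
state_dict = {
    "model.layers.0.q_proj.base_layer.weight": "W",
    "model.layers.0.q_proj.base_layer.weight_scale": "S",
    "model.layers.0.q_proj.lora_A.default.weight": "A",
    "model.layers.0.q_proj.lora_B.default.weight": "B",
}

post = {}
for key, value in state_dict.items():
    if "lora" in key:
        continue  # adapters are saved separately, not in the base checkpoint
    for old, new in replacements.items():
        if key.endswith(old):
            key = key[: -len(old)] + new
            break
    post[key] = value

print(sorted(post))
```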
12 changes: 7 additions & 5 deletions modelopt/torch/export/unified_export_hf.py
@@ -365,9 +365,7 @@ def _export_quantized_weight(


def _export_hf_checkpoint(
model: nn.Module,
dtype: torch.dtype | None = None,
**kwargs,
model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Exports the torch model to the packed checkpoint with original HF naming.

@@ -458,7 +456,7 @@ def _export_hf_checkpoint(
except ImportError:
warnings.warn("accelerate is not installed, hooks will not be removed")

quant_config = get_quant_config(model)
quant_config = get_quant_config(model, is_modelopt_qlora=is_modelopt_qlora)

kv_cache_max_bound = 0
kv_cache_format = quant_config["quantization"]["kv_cache_quant_algo"]
@@ -488,6 +486,10 @@

fsdp_module_to_reshard = sub_module

# We skip QuantLoraLinear module for modelopt QLoRA
if is_modelopt_qlora and (hasattr(sub_module, "base_layer")):
continue

if get_quantization_format(sub_module) != QUANTIZATION_NONE:
has_quantized_layers = True
if is_quantlinear(sub_module):
@@ -520,7 +522,7 @@
quantized_state_dict = model.state_dict()

quantized_state_dict = postprocess_state_dict(
quantized_state_dict, kv_cache_max_bound, kv_cache_format
quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora
)

# Check if any layers are quantized
27 changes: 24 additions & 3 deletions modelopt/torch/quantization/plugins/transformers_trainer.py
@@ -146,7 +146,7 @@ def __init__(
self.model, "peft_config"
):
# TODO: use get_peft_model here instead of add_adapter
self.model.add_adapter(self.args.lora_config, adapter_name="adapter")
self.model.add_adapter(self.args.lora_config)
print_rank_0("Lora adapter added.")

if hasattr(self.model, "peft_config") and self.quant_cfg is not None:
@@ -209,14 +209,16 @@ def forward_loop(model):
print_rank_0("Quantizing the model...")
mtq.quantize(self.model, self.quant_cfg, forward_loop) # type: ignore [arg-type]

# Save modelopt state
self._save_modelopt_state_with_weights()

if getattr(self.quant_args, "compress", False):
print_rank_0("Compressing model after calibration")
mtq.compress(self.model)

# Force garbage collection to free up memory
gc.collect()

self._save_modelopt_state_with_weights()
torch.cuda.empty_cache()

if self.accelerator.is_main_process:
@@ -275,6 +277,25 @@ def save_model(self, *args, **kwargs):
outputs = super().save_model(*args, **kwargs)
return outputs

def _load_best_model(self, *args, **kwargs):
"""Load the best model for final evaluation."""
is_lora = getattr(self.args, "lora", None)
if is_lora and not self.is_fsdp_enabled:
# Custom logic for loading best model with LoRA
# TODO: Remove once we migrate to using get_peft_model()
# This custom logic only loads best adapters. Ensure base model is frozen
assert all(
not param.requires_grad
for name, param in self.model.base_model.named_parameters()
if "base_layer" in name
), "Some base_layer parameters are not frozen"

adapter_name = self.model.active_adapter()
self.model.delete_adapter(adapter_name)
self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
else:
super()._load_best_model(*args, **kwargs)

def _patch_accelerate_for_fsdp2_fix(self):
"""Fixes for accelerate prepare.

@@ -337,7 +358,7 @@ def __init__(
if self.quant_cfg is not None and not is_quantized(self.model):
self._quantize_model()
if getattr(self.args, "lora_config", None) is not None:
self.model.add_adapter(self.args.lora_config, adapter_name="adapter")
self.model.add_adapter(self.args.lora_config)
print_rank_0("Lora adapter added.")
self._convert_to_distillation_model()
