Commit 7d0f7a9

sugunav14 and realAsma authored
QLoRA DDP export (#353)
## What does this PR do?

**Type of change:** New example

**Overview:** This PR provides an e2e example for fine-tuning a model using QLoRA with DDP and exporting checkpoint for deployment using vllm.

1. This PR contains a temporary fix for loading best checkpoint in the end for DDP which can be removed once we move to using get_peft_model()
2. The final base checkpoint is exported under output_dir/base_model while the adapter weights are exported under output_dir

## Usage

Refer to README.md changes

## Testing

Trainer
- [x] `./launch.sh --model meta-llama/Meta-Llama-3-8B --num_epochs 0.01 --lr 1e-3 --do_train True --output_dir test --quant_cfg FP8_DEFAULT_CFG --compress True --lora True`

Export
- [x] `python export.py --pyt_ckpt_path test --export_dir test-fp8`

Deployment
- [x] `vllm serve test-fp8/base_model --enable-lora --lora-modules sql-lora=test-fp8 --port 8090 --tokenizer test-fp8`
- [x] e2e unit test
- [ ] Sanity check weights, dtypes of generated checkpoint
- [x] Test phi4

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

## Summary by CodeRabbit

* **New Features**
  * Added an export CLI/tool to produce HuggingFace-ready checkpoints from LoRA/QLoRA-trained models, including optional restoration of quantizer state.
* **Improvements**
  * Export now respects a QLoRA mode, filters and strips adapter entries appropriately, and only emits per-layer quantization when weight quantization is enabled. Saves model state earlier after quantization and tightens checks for exporting quantized weights.
* **Documentation**
  * Expanded LLM QAT README with QLoRA export and deployment guidance.
* **Tests**
  * Re-enabled a previously skipped QLoRA test.

---------

Signed-off-by: Suguna Velury <[email protected]>
Signed-off-by: sugunav14 <[email protected]>
Co-authored-by: realAsma <[email protected]>
1 parent b1fc1fe commit 7d0f7a9

7 files changed

+224
-23
lines changed


CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
@@ -35,10 +35,12 @@ Model Optimizer Changelog (Linux)
3535
- Add support for ``torch.compile`` and benchmarking in ``examples/diffusers/quantization/diffusion_trt.py``.
3636
- Enabled native Modelopt quantization support for FP8 and NVFP4 formats in SGLang. See `SGLang quantization documentation <https://github.com/sgl-project/sglang/blob/main/docs/advanced_features/quantization.md#using-nvidia-modelopt>`_ for more details.
3737
- Added modelopt quantized checkpoints in vLLM/SGLang CI/CD pipelines (PRs are under review).
38+
- Add support for exporting QLoRA checkpoints fine-tuned using ModelOpt.
3839

3940
**Documentation**
4041

4142
- Add general guidelines for Minitron pruning and distillation. See `examples/pruning/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/pruning#pruning-guidelines>`_ for more details.
43+
- Added example for exporting QLoRA checkpoint for vLLM deployment. Refer to `examples/llm_qat/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/79ef31bc7269ba4da0cfab446da5b64509cbfcef/examples/llm_qat/README.md#qlora-deployment>`_ for more details
4244

4345
0.37 (2025-10-08)
4446
^^^^^^^^^^^^^^^^^

examples/llm_qat/README.md

Lines changed: 21 additions & 6 deletions
@@ -301,14 +301,12 @@ See more details on running LLM evaluation benchmarks [here](../llm_eval/README.
301301

302302
## Deployment
303303

304-
The final model after QAT is similar in architecture to that of PTQ model. QAT model simply have updated weights as compared to the PTQ model. It can be deployed to TensorRT-LLM (TRTLLM) or to TensorRT just like a regular **ModelOpt** PTQ model if the quantization format is supported for deployment.
304+
The final model after QAT/QAD is similar in architecture to the PTQ model; the QAT model simply has updated weights compared to the PTQ model. It can be deployed to TensorRT-LLM (TRTLLM), TensorRT, vLLM, or SGLang just like a regular **ModelOpt** PTQ model if the quantization format is supported for deployment.
305305

306-
To run QAT model with TRTLLM, run:
306+
To export a TRTLLM/vLLM/SGLang-compatible checkpoint for the model after QAT (or QAD), run:
307307

308308
```sh
309-
cd ../llm_ptq
310-
311-
./scripts/huggingface_example.sh --model ../llm_qat/llama3-qat --quant w4a8_awq
309+
python export.py --pyt_ckpt_path llama3-qat --export_path llama3-qat-deploy
312310
```
313311

314312
Note: The QAT checkpoint for `w4a8_awq` config can be created by using `--quant_cfg W4A8_AWQ_BETA_CFG` in [QAT example](#end-to-end-qat-example).
@@ -345,8 +343,25 @@ To perform QLoRA training, run:
345343
--lora True
346344
```
347345

348-
> **_NOTE:_** QLoRA is currently an experimental feature designed to reduce the memory footprint during training. Deployment functionality is not yet available.
346+
## QLoRA deployment
347+
348+
After performing QLoRA training, the final checkpoint can be exported for deployment with vLLM using the following command:
349+
350+
```sh
351+
python export.py \
352+
--pyt_ckpt_path llama3-fp4-qlora \
353+
--export_path llama3-fp4-qlora-hf
354+
355+
```
356+
357+
To deploy with vLLM, run the following command. For more details about QLoRA deployment using vLLM refer to the documentation [here](https://docs.vllm.ai/en/latest/features/lora.html).
358+
359+
```sh
360+
vllm serve llama3-fp4-qlora-hf/base_model --enable-lora --lora-modules adapter=llama3-fp4-qlora-hf --port 8000 --tokenizer llama3-fp4-qlora-hf
361+
```
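Once the server from the command above is running, the adapter is selected by passing its registered name (`adapter` in the `--lora-modules` flag) as the `model` field of an OpenAI-compatible completion request. The following is a minimal sketch using only the standard library. The helper names `build_completion_request` and `send`, the default `max_tokens`, and the assumption that the server listens on `localhost:8000` are illustrative and not part of the example scripts.

```python
import json
import urllib.request


def build_completion_request(prompt, model="adapter",
                             base_url="http://localhost:8000", max_tokens=64):
    """Build (but do not send) a POST request for the /v1/completions route."""
    # "model" names the LoRA module registered via --lora-modules; passing the
    # served base model name instead would bypass the adapter.
    payload = {"model": model, "prompt": prompt, "max_tokens": max_tokens}
    return urllib.request.Request(
        url=f"{base_url}/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request, timeout=30):
    """Send the request and return the decoded JSON response (needs a live server)."""
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        return json.load(resp)
```

With the server up, `send(build_completion_request("SELECT"))` should return a JSON object in the OpenAI completions format; the exact fields depend on the vLLM version in use.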
349362

363+
> _Note: Export is not currently supported for QLoRA models trained with FSDP2._
364+
>
350365
## Pre-Quantized Checkpoints
351366

352367
- Ready-to-deploy checkpoints \[[🤗 Hugging Face - Nvidia TensorRT Model Optimizer Collection](https://huggingface.co/collections/nvidia/inference-optimized-checkpoints-with-model-optimizer)\]

examples/llm_qat/export.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import argparse
17+
import json
18+
import warnings
19+
from pathlib import Path
20+
21+
import torch
22+
from transformers import AutoModelForCausalLM, AutoTokenizer
23+
24+
import modelopt.torch.opt as mto
25+
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
26+
from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
27+
from modelopt.torch.opt.conversion import restore_from_modelopt_state
28+
from modelopt.torch.quantization.utils import set_quantizer_state_dict
29+
from modelopt.torch.utils import print_rank_0
30+
31+
RAND_SEED = 1234
32+
33+
# Enable automatic save/load of modelopt state with huggingface checkpointing
34+
mto.enable_huggingface_checkpointing()
35+
36+
37+
def get_model(
38+
ckpt_path: str,
39+
device="cuda",
40+
):
41+
"""
42+
Loads a QLoRA model that has been trained using modelopt trainer.
43+
"""
44+
# TODO: Add support for merging adapters in BF16 and merging adapters with quantization for deployment
45+
device_map = "auto"
46+
if device == "cpu":
47+
device_map = "cpu"
48+
49+
# Load model
50+
model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)
51+
52+
# Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
53+
if hasattr(model, "peft_config"):
54+
modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
55+
restore_from_modelopt_state(model, modelopt_state)
56+
print_rank_0("Restored modelopt state")
57+
58+
# Restore modelopt quantizer state dict
59+
modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
60+
if modelopt_weights is not None:
61+
set_quantizer_state_dict(model, modelopt_weights)
62+
print_rank_0("Restored modelopt quantizer state dict")
63+
64+
return model
65+
66+
67+
def main(args):
68+
# Load model
69+
model = get_model(args.pyt_ckpt_path, args.device)
70+
tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
71+
is_qlora = hasattr(model, "peft_config")
72+
73+
# Export HF checkpoint
74+
export_dir = Path(args.export_path)
75+
export_dir.mkdir(parents=True, exist_ok=True)
76+
if is_qlora:
77+
base_model_dir = export_dir / "base_model"
78+
base_model_dir.mkdir(parents=True, exist_ok=True)
79+
else:
80+
base_model_dir = export_dir
81+
82+
try:
83+
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=is_qlora)
84+
85+
with open(f"{base_model_dir}/hf_quant_config.json", "w") as file:
86+
json.dump(hf_quant_config, file, indent=4)
87+
88+
hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
89+
90+
# Save model
91+
if is_qlora:
92+
model.base_model.save_pretrained(f"{base_model_dir}", state_dict=post_state_dict)
93+
model.save_pretrained(export_dir)
94+
else:
95+
model.save_pretrained(export_dir, state_dict=post_state_dict)
96+
97+
config_path = f"{base_model_dir}/config.json"
98+
99+
config_data = model.config.to_dict()
100+
101+
config_data["quantization_config"] = hf_quant_config
102+
103+
with open(config_path, "w") as file:
104+
json.dump(config_data, file, indent=4)
105+
106+
# Save tokenizer
107+
tokenizer.save_pretrained(export_dir)
108+
109+
except Exception as e:
110+
warnings.warn(
111+
"Cannot export model to the model_config. The modelopt-optimized model state_dict"
112+
" can be saved with torch.save for further inspection."
113+
)
114+
raise e
115+
116+
117+
if __name__ == "__main__":
118+
parser = argparse.ArgumentParser(description=__doc__)
119+
parser.add_argument(
120+
"--pyt_ckpt_path",
121+
help="Path to the PyTorch checkpoint directory",
122+
required=True,
123+
)
124+
125+
parser.add_argument("--device", default="cuda")
126+
127+
parser.add_argument(
128+
"--export_path",
129+
default="exported_model",
130+
help="Path to save the exported model",
131+
)
132+
133+
args = parser.parse_args()
134+
135+
main(args)
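For a QLoRA export, the script above writes `hf_quant_config.json` and `config.json` under `base_model/` inside the export directory. A quick sanity check of that layout could look like the sketch below. The helper name `missing_export_files` is hypothetical, and the file list covers only the two files that `export.py` writes explicitly; weight, adapter, and tokenizer file names are left out because they depend on the `save_pretrained` implementation.

```python
from pathlib import Path

# Files written explicitly by export.py for a QLoRA export,
# relative to the directory passed as --export_path.
EXPECTED_QLORA_FILES = (
    "base_model/config.json",
    "base_model/hf_quant_config.json",
)


def missing_export_files(export_dir, expected=EXPECTED_QLORA_FILES):
    """Return the expected files that are absent from export_dir, in order."""
    root = Path(export_dir)
    return [rel for rel in expected if not (root / rel).is_file()]
```

An empty list means both files are present; anything else names what is missing before the directory is handed to a serving stack.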

modelopt/torch/export/quant_utils.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -827,13 +827,19 @@ def from_quantized_weight(
827827
raise NotImplementedError(f"quantization format {quantization} not supported")
828828

829829

830-
def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str | None) -> dict:
830+
def postprocess_state_dict(
831+
state_dict: dict,
832+
maxbound: float,
833+
quantization: str | None,
834+
is_modelopt_qlora: bool = False,
835+
) -> dict:
831836
"""Filters out keys related to weight quantizers and updates KV cache related keys.
832837
833838
Args:
834839
state_dict: The full model state_dict.
835840
maxbound: The maximum bound value for the output quantizer.
836841
quantization: The KV cache quantization format.
842+
is_modelopt_qlora: Whether the model is a modelopt-trained QLoRA model.
837843
838844
Returns:
839845
The filtered state_dict without unnecessary keys like '_amax' and non KV cache output quantizers.
@@ -845,6 +851,18 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
845851
"v_bmm_quantizer._bias_value": "v_proj.v_bias",
846852
"input_quantizer._pre_quant_scale": "pre_quant_scale",
847853
}
854+
skip_keys = ["output_quantizer", "_amax", "_bias_value", "input_quantizer._pre_quant_scale"]
855+
856+
# For modelopt-trained LoRA models, we need to remove the base_layer prefix from the keys for deployment
857+
if is_modelopt_qlora:
858+
replacements.update(
859+
{
860+
"base_layer.weight": "weight",
861+
"base_layer.input_scale": "input_scale",
862+
"base_layer.weight_scale": "weight_scale",
863+
}
864+
)
865+
skip_keys.append("base_layer")
848866

849867
post_state_dict = {}
850868

@@ -855,12 +873,7 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
855873
continue
856874

857875
# Skip keys not related to quantizers
858-
if (
859-
"output_quantizer" not in key
860-
and "_amax" not in key
861-
and "_bias_value" not in key
862-
and "input_quantizer._pre_quant_scale" not in key
863-
):
876+
if all(skip_key not in key for skip_key in skip_keys):
864877
post_state_dict[key] = value
865878
continue
866879

@@ -911,6 +924,11 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
911924
):
912925
keys_to_delete.append(key)
913926

927+
# remove LoRA adapters from state dict
928+
if is_modelopt_qlora:
929+
for key in post_state_dict:
930+
if "lora" in key and key not in keys_to_delete:
931+
keys_to_delete.append(key)
914932
# Check for tied weights and remove duplicates
915933
seen_tensors = {}
916934
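The QLoRA-specific part of the hunks above does two things: it renames `base_layer.*` keys to the plain names a deployment framework expects, and it drops every key containing `lora`. A simplified, standalone sketch of just that behavior follows. The function name `strip_qlora_keys` and the suffix-based matching are assumptions made for illustration; the real `postprocess_state_dict` also handles quantizer keys, KV cache scales, and tied weights.

```python
# Illustration only: mirrors the three base_layer replacements added in the diff.
BASE_LAYER_RENAMES = {
    "base_layer.weight": "weight",
    "base_layer.input_scale": "input_scale",
    "base_layer.weight_scale": "weight_scale",
}


def strip_qlora_keys(state_dict):
    """Drop LoRA adapter entries and remove the base_layer prefix from the rest."""
    cleaned = {}
    for key, value in state_dict.items():
        if "lora" in key:
            continue  # adapter weights are saved separately from the base model
        for old, new in BASE_LAYER_RENAMES.items():
            if key.endswith(old):
                key = key[: -len(old)] + new
                break
        cleaned[key] = value
    return cleaned
```

For example, `model.layers.0.q_proj.base_layer.weight` becomes `model.layers.0.q_proj.weight`, while `model.layers.0.q_proj.lora_A.default.weight` is removed.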

@@ -1029,6 +1047,7 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
10291047

10301048
def get_quant_config(
10311049
model: nn.Module,
1050+
is_modelopt_qlora: bool = False,
10321051
) -> dict[str, Any]:
10331052
"""Generate quantization config for a model.
10341053
@@ -1037,6 +1056,7 @@ def get_quant_config(
10371056
10381057
Args:
10391058
model: The PyTorch model to make config for.
1059+
is_modelopt_qlora: Whether the model is a modelopt-trained QLoRA model.
10401060
10411061
Returns:
10421062
Dictionary containing the quantization configuration
@@ -1073,7 +1093,14 @@ def get_quant_config(
10731093
or hasattr(module, quantizer_attr_names(weight_name).input_quantizer)
10741094
for weight_name in weight_names
10751095
)
1076-
if has_quantizers:
1096+
1097+
# Skip LORA module and adapters.
1098+
# ModelOpt does not currently quantize these layers in QLoRA path.
1099+
skip_layer = is_modelopt_qlora and (
1100+
hasattr(module, "base_layer") or "lora_A" in name or "lora_B" in name
1101+
)
1102+
1103+
if has_quantizers and not skip_layer:
10771104
quantization_format = get_quantization_format(module)
10781105

10791106
# For MoE expert modules, we need to extract block size from the correct weight quantizer
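The `skip_layer` condition in the hunk above can be read as a small predicate: in QLoRA mode, a module is excluded from the per-layer quantization config if it is a LoRA wrapper (it has a `base_layer` attribute) or if its name marks it as an adapter sub-layer. The sketch below restates that predicate in isolation. The function name `should_skip_layer` is hypothetical, and `SimpleNamespace` objects stand in for real `torch.nn.Module` instances.

```python
from types import SimpleNamespace


def should_skip_layer(name, module, is_modelopt_qlora):
    """Return True when a module gets no per-layer quant entry in QLoRA mode."""
    return is_modelopt_qlora and (
        hasattr(module, "base_layer") or "lora_A" in name or "lora_B" in name
    )


# Stand-ins: a LoRA wrapper exposes base_layer, a plain linear layer does not.
wrapper = SimpleNamespace(base_layer=object())
plain = SimpleNamespace()
```

Note that the wrapped `base_layer` module itself is not skipped, so its quantization format is still recorded.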

modelopt/torch/export/unified_export_hf.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,7 @@ def _export_quantized_weight(
365365

366366

367367
def _export_hf_checkpoint(
368-
model: nn.Module,
369-
dtype: torch.dtype | None = None,
370-
**kwargs,
368+
model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs
371369
) -> tuple[dict[str, Any], dict[str, Any]]:
372370
"""Exports the torch model to the packed checkpoint with original HF naming.
373371
@@ -458,7 +456,7 @@ def _export_hf_checkpoint(
458456
except ImportError:
459457
warnings.warn("accelerate is not installed, hooks will not be removed")
460458

461-
quant_config = get_quant_config(model)
459+
quant_config = get_quant_config(model, is_modelopt_qlora=is_modelopt_qlora)
462460

463461
kv_cache_max_bound = 0
464462
kv_cache_format = quant_config["quantization"]["kv_cache_quant_algo"]
@@ -488,6 +486,10 @@ def _export_hf_checkpoint(
488486

489487
fsdp_module_to_reshard = sub_module
490488

489+
# We skip QuantLoraLinear module for modelopt QLoRA
490+
if is_modelopt_qlora and (hasattr(sub_module, "base_layer")):
491+
continue
492+
491493
if get_quantization_format(sub_module) != QUANTIZATION_NONE:
492494
has_quantized_layers = True
493495
if is_quantlinear(sub_module):
@@ -520,7 +522,7 @@ def _export_hf_checkpoint(
520522
quantized_state_dict = model.state_dict()
521523

522524
quantized_state_dict = postprocess_state_dict(
523-
quantized_state_dict, kv_cache_max_bound, kv_cache_format
525+
quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora
524526
)
525527

526528
# Check if any layers are quantized

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def __init__(
146146
self.model, "peft_config"
147147
):
148148
# TODO: use get_peft_model here instead of add_adapter
149-
self.model.add_adapter(self.args.lora_config, adapter_name="adapter")
149+
self.model.add_adapter(self.args.lora_config)
150150
print_rank_0("Lora adapter added.")
151151

152152
if hasattr(self.model, "peft_config") and self.quant_cfg is not None:
@@ -209,14 +209,16 @@ def forward_loop(model):
209209
print_rank_0("Quantizing the model...")
210210
mtq.quantize(self.model, self.quant_cfg, forward_loop) # type: ignore [arg-type]
211211

212+
# Save modelopt state
213+
self._save_modelopt_state_with_weights()
214+
212215
if getattr(self.quant_args, "compress", False):
213216
print_rank_0("Compressing model after calibration")
214217
mtq.compress(self.model)
215218

216219
# Force garbage collection to free up memory
217220
gc.collect()
218221

219-
self._save_modelopt_state_with_weights()
220222
torch.cuda.empty_cache()
221223

222224
if self.accelerator.is_main_process:
@@ -275,6 +277,25 @@ def save_model(self, *args, **kwargs):
275277
outputs = super().save_model(*args, **kwargs)
276278
return outputs
277279

280+
def _load_best_model(self, *args, **kwargs):
281+
"""Load the best model for final evaluation."""
282+
is_lora = getattr(self.args, "lora", None)
283+
if is_lora and not self.is_fsdp_enabled:
284+
# Custom logic for loading best model with LoRA
285+
# TODO: Remove once we migrate to using get_peft_model()
286+
# This custom logic only loads best adapters. Ensure base model is frozen
287+
assert all(
288+
not param.requires_grad
289+
for name, param in self.model.base_model.named_parameters()
290+
if "base_layer" in name
291+
), "Some base_layer parameters are not frozen"
292+
293+
adapter_name = self.model.active_adapter()
294+
self.model.delete_adapter(adapter_name)
295+
self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
296+
else:
297+
super()._load_best_model(*args, **kwargs)
298+
278299
def _patch_accelerate_for_fsdp2_fix(self):
279300
"""Fixes for accelerate prepare.
280301
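The custom `_load_best_model` path reloads only the adapter, so it is only safe when the base weights did not change during training; that is what the assertion guards. A standalone sketch of the same check, returning the offending names instead of asserting, is shown below. The helper name `unfrozen_base_params` is hypothetical, and it accepts any iterable of `(name, param)` pairs whose `param` has a `requires_grad` attribute, as `named_parameters()` provides.

```python
from types import SimpleNamespace


def unfrozen_base_params(named_parameters):
    """Return names of base_layer parameters that still require gradients."""
    return [
        name
        for name, param in named_parameters
        if "base_layer" in name and param.requires_grad
    ]


# Stand-in parameters; real ones would come from model.named_parameters().
example_params = [
    ("q_proj.base_layer.weight", SimpleNamespace(requires_grad=False)),
    ("q_proj.lora_A.default.weight", SimpleNamespace(requires_grad=True)),
]
```

Listing the names makes a failure easier to diagnose than a bare assertion message, at the cost of a slightly longer check.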
@@ -337,7 +358,7 @@ def __init__(
337358
if self.quant_cfg is not None and not is_quantized(self.model):
338359
self._quantize_model()
339360
if getattr(self.args, "lora_config", None) is not None:
340-
self.model.add_adapter(self.args.lora_config, adapter_name="adapter")
361+
self.model.add_adapter(self.args.lora_config)
341362
print_rank_0("Lora adapter added.")
342363
self._convert_to_distillation_model()
343364
