23 changes: 18 additions & 5 deletions examples/llm_qat/README.md
@@ -303,12 +303,10 @@ See more details on running LLM evaluation benchmarks [here](../llm_eval/README.

The final model after QAT has the same architecture as the PTQ model; it simply has updated weights compared to the PTQ model. It can be deployed to TensorRT-LLM (TRTLLM) or to TensorRT just like a regular **ModelOpt** PTQ model if the quantization format is supported for deployment.

To run QAT model with TRTLLM, run:
To run the QAT model with vLLM/TRTLLM, run:

```sh
cd ../llm_ptq

./scripts/huggingface_example.sh --model ../llm_qat/llama3-qat --quant w4a8_awq
python export.py --pyt_ckpt_path llama3-qat --export_path llama3-qat-deploy
```
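
For reference, the exported checkpoint can then be served with vLLM. A minimal sketch (the checkpoint path and port are illustrative, and the quantization format must be one your vLLM build supports):

```sh
vllm serve llama3-qat-deploy --port 8000
```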

Note: The QAT checkpoint for `w4a8_awq` config can be created by using `--quant_cfg W4A8_AWQ_BETA_CFG` in [QAT example](#end-to-end-qat-example).
@@ -345,7 +343,22 @@ To perform QLoRA training, run:
--lora True
```

> **_NOTE:_** QLoRA is currently an experimental feature designed to reduce the memory footprint during training. Deployment functionality is not yet available.
## QLoRA Deployment

After performing QLoRA training, the final checkpoint can be exported for deployment with vLLM using the following command:

```sh
python export.py \
--pyt_ckpt_path llama3-fp4-qlora \
    --export_path llama3-fp4-qlora-hf
```

To deploy with vLLM, run the following command. The exported directory contains the quantized base model under `base_model/` and the LoRA adapter files at the top level. For more details about QLoRA deployment with vLLM, refer to the [vLLM LoRA documentation](https://docs.vllm.ai/en/latest/features/lora.html).

```sh
vllm serve llama3-fp4-qlora-hf/base_model --enable-lora --lora-modules adapter=llama3-fp4-qlora-hf --port 8000 --tokenizer llama3-fp4-qlora-hf
```
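
Once the server is running, requests can target the LoRA adapter by the name given to `--lora-modules` (here `adapter`). A minimal sketch using the OpenAI-compatible completions endpoint (the prompt and token count are illustrative):

```sh
curl http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "adapter", "prompt": "Hello, my name is", "max_tokens": 32}'
```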

## Pre-Quantized Checkpoints

135 changes: 135 additions & 0 deletions examples/llm_qat/export.py
@@ -0,0 +1,135 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import warnings
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.opt as mto
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
from modelopt.torch.opt.conversion import restore_from_modelopt_state
from modelopt.torch.quantization.utils import set_quantizer_state_dict
from modelopt.torch.utils import print_rank_0

RAND_SEED = 1234

# Enable automatic save/load of the modelopt state during HuggingFace checkpointing
mto.enable_huggingface_checkpointing()


def get_lora_model(
ckpt_path: str,
device="cuda",
):
"""
Loads a QLoRA model that has been trained using modelopt trainer.
"""
# TODO: Add support for merging adapters in BF16 and merging adapters with quantization for deployment
device_map = "auto"
if device == "cpu":
device_map = "cpu"

# Load model
model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)

    # Restore modelopt state for LoRA models. For QAT/QAD models, the from_pretrained call handles this
if hasattr(model, "peft_config"):
modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
restore_from_modelopt_state(model, modelopt_state)
print_rank_0("Restored modelopt state")

# Restore modelopt quantizer state dict
modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
if modelopt_weights is not None:
set_quantizer_state_dict(model, modelopt_weights)
print_rank_0("Restored modelopt quantizer state dict")
Comment on lines +53 to +62
⚠️ Potential issue | 🟠 Major

Verify the state restoration order matches the trainer's pattern.

The sequence here differs from the trainer's _restore_modelopt_state_with_weights method (transformers_trainer.py lines 184-190):

Trainer pattern:

  1. Load state dict
  2. Pop modelopt_state_weights
  3. Call restore_from_modelopt_state
  4. Set quantizer state dict

Current export.py pattern:

  1. Load state dict
  2. Call restore_from_modelopt_state (line 55)
  3. Pop modelopt_state_weights (line 59)
  4. Set quantizer state dict

The pop operation at line 59 occurs after restore_from_modelopt_state, which may cause the method to receive and process the weights that should be handled separately. Align the order with the trainer's proven pattern.

Apply this diff:

     # Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
     if hasattr(model, "peft_config"):
         modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
+        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
         restore_from_modelopt_state(model, modelopt_state)
         print_rank_0("Restored modelopt state")
 
         # Restore modelopt quantizer state dict
-        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
         if modelopt_weights is not None:
             set_quantizer_state_dict(model, modelopt_weights)
             print_rank_0("Restored modelopt quantizer state dict")


return model


def main(args):
# Load model
model = get_lora_model(args.pyt_ckpt_path, args.device)
tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
is_qlora = hasattr(model, "peft_config")

# Export HF checkpoint
export_dir = Path(args.export_path)
export_dir.mkdir(parents=True, exist_ok=True)
if is_qlora:
base_model_dir = export_dir / "base_model"
base_model_dir.mkdir(parents=True, exist_ok=True)
else:
base_model_dir = export_dir

try:
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=is_qlora)

with open(f"{base_model_dir}/hf_quant_config.json", "w") as file:
json.dump(hf_quant_config, file, indent=4)

hf_quant_config = convert_hf_quant_config_format(hf_quant_config)

# Save model
if is_qlora:
model.base_model.save_pretrained(f"{base_model_dir}", state_dict=post_state_dict)
model.save_pretrained(export_dir)
else:
model.save_pretrained(export_dir, state_dict=post_state_dict)

config_path = f"{base_model_dir}/config.json"

config_data = model.config.to_dict()

config_data["quantization_config"] = hf_quant_config

with open(config_path, "w") as file:
json.dump(config_data, file, indent=4)

# Save tokenizer
tokenizer.save_pretrained(export_dir)

except Exception as e:
warnings.warn(
"Cannot export model to the model_config. The modelopt-optimized model state_dict"
" can be saved with torch.save for further inspection."
)
raise e


if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--pyt_ckpt_path",
help="Specify where the PyTorch checkpoint path is",
required=True,
)

parser.add_argument("--device", default="cuda")

parser.add_argument(
"--export_path",
default="exported_model",
help="Path to save the exported model",
)

args = parser.parse_args()

main(args)
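
A usage sketch for this script, with paths taken from the README examples above: for a QLoRA checkpoint the quantized base model is written to `<export_path>/base_model` and the adapter files to `<export_path>`; for a QAT/QAD checkpoint everything is written directly to `<export_path>`.

```sh
# QLoRA checkpoint: exports base_model/ plus the LoRA adapter files
python export.py --pyt_ckpt_path llama3-fp4-qlora --export_path llama3-fp4-qlora-hf

# QAT/QAD checkpoint: exports a flat Hugging Face checkpoint
python export.py --pyt_ckpt_path llama3-qat --export_path llama3-qat-deploy
```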
40 changes: 33 additions & 7 deletions modelopt/torch/export/quant_utils.py
@@ -815,13 +815,19 @@ def from_quantized_weight(
raise NotImplementedError(f"quantization format {quantization} not supported")


def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str | None) -> dict:
def postprocess_state_dict(
state_dict: dict,
maxbound: float,
quantization: str | None,
is_modelopt_qlora: bool = False,
) -> dict:
"""Filters out keys related to weight quantizers and updates KV cache related keys.
Args:
state_dict: The full model state_dict.
maxbound: The maximum bound value for the output quantizer.
quantization: The KV cache quantization format.
is_modelopt_qlora: Whether the model is a modelopt-trained QLoRA model.
Returns:
The filtered state_dict without unnecessary keys like '_amax' and non KV cache output quantizers.
@@ -833,17 +839,24 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
"v_bmm_quantizer._bias_value": "v_proj.v_bias",
"input_quantizer._pre_quant_scale": "pre_quant_scale",
}
skip_keys = ["output_quantizer", "_amax", "_bias_value", "input_quantizer._pre_quant_scale"]

# For modelopt-trained LoRA models, we need to remove the base_layer prefix from the keys for deployment
if is_modelopt_qlora:
replacements.update(
{
"base_layer.weight": "weight",
"base_layer.input_scale": "input_scale",
"base_layer.weight_scale": "weight_scale",
}
)
skip_keys.append("base_layer")

post_state_dict = {}

for key, value in state_dict.items():
# Skip keys not related to quantizers
if (
"output_quantizer" not in key
and "_amax" not in key
and "_bias_value" not in key
and "input_quantizer._pre_quant_scale" not in key
):
if all(skip_key not in key for skip_key in skip_keys):
post_state_dict[key] = value
continue

@@ -894,6 +907,11 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
):
keys_to_delete.append(key)

# remove LoRA adapters from state dict
if is_modelopt_qlora:
for key in post_state_dict:
if "lora" in key and key not in keys_to_delete:
keys_to_delete.append(key)
# Check for tied weights and remove duplicates
seen_tensors = {}
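
As an illustration of the QLoRA key rewriting configured above, a small standalone sketch (the layer name is hypothetical; the real function applies these replacements inside its key-processing loop):

```python
replacements = {
    "base_layer.weight": "weight",
    "base_layer.input_scale": "input_scale",
    "base_layer.weight_scale": "weight_scale",
}

key = "model.layers.0.self_attn.q_proj.base_layer.weight"
for old_suffix, new_suffix in replacements.items():
    if key.endswith(old_suffix):
        # Strip the PEFT "base_layer" wrapper so the key matches the plain HF module name
        key = key[: -len(old_suffix)] + new_suffix
print(key)  # model.layers.0.self_attn.q_proj.weight
```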

@@ -1072,6 +1090,14 @@ def get_quant_config(named_modules: nn.Module | dict[str, nn.Module]) -> dict[st
if block_size == 0:
block_size = get_weight_block_size(module)

# In the case of NVFP4, block_size 0 indicates weight_quantizer is not enabled
if block_size == 0 and quantization_format in [
Collaborator: do we have a better flag instead of checking the block_size? E.g. weight_quantizer enabled vs disabled?

Author: let me check and update the PR!

Contributor: +1

QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
continue

# Construct per layer config dictionary
layer_config_dict[name + ".quantization"] = quantization_format
layer_config_dict[name + ".awq_block_size"] = block_size
10 changes: 7 additions & 3 deletions modelopt/torch/export/unified_export_hf.py
@@ -338,7 +338,7 @@ def _export_quantized_weight(


def _export_hf_checkpoint(
model: nn.Module, dtype: torch.dtype | None = None
model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Exports the torch model to the packed checkpoint with original HF naming.
@@ -461,7 +461,11 @@ def _export_hf_checkpoint(
for name, sub_module in layer_pool.items():
if get_quantization_format(sub_module) != QUANTIZATION_NONE:
has_quantized_layers = True
if is_quantlinear(sub_module):
if (
is_quantlinear(sub_module)
and hasattr(sub_module, "weight_quantizer")
and sub_module.weight_quantizer.is_enabled
):
_export_quantized_weight(sub_module, dtype)
elif (
"Llama4TextExperts" in type(sub_module).__name__
Expand All @@ -485,7 +489,7 @@ def _export_hf_checkpoint(
quantized_state_dict = model.state_dict()

quantized_state_dict = postprocess_state_dict(
quantized_state_dict, kv_cache_max_bound, kv_cache_format
quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora
)

# Check if any layers are quantized
20 changes: 17 additions & 3 deletions modelopt/torch/quantization/plugins/transformers_trainer.py
@@ -146,7 +146,7 @@ def __init__(
self.model, "peft_config"
):
# TODO: use get_peft_model here instead of add_adapter
self.model.add_adapter(self.args.lora_config, adapter_name="adapter")
self.model.add_adapter(self.args.lora_config)
print_rank_0("Lora adapter added.")

if hasattr(self.model, "peft_config") and self.quant_cfg is not None:
Contributor: Why do we need disable_lora_quantizers_in_config? This does not seem warranted to me.

@@ -209,14 +209,16 @@ def forward_loop(model):
print_rank_0("Quantizing the model...")
mtq.quantize(self.model, self.quant_cfg, forward_loop) # type: ignore [arg-type]

# Save modelopt state
self._save_modelopt_state_with_weights()

if getattr(self.quant_args, "compress", False):
print_rank_0("Compressing model after calibration")
mtq.compress(self.model)

# Force garbage collection to free up memory
gc.collect()

self._save_modelopt_state_with_weights()
torch.cuda.empty_cache()

if self.accelerator.is_main_process:
Expand Down Expand Up @@ -275,6 +277,18 @@ def save_model(self, *args, **kwargs):
outputs = super().save_model(*args, **kwargs)
return outputs

def _load_best_model(self, *args, **kwargs):
"""Load the best model for final evaluation."""
is_lora = getattr(self.args, "lora", None)
if is_lora and not self.is_fsdp_enabled:
# Custom logic for loading best model with LoRA
# TODO: Remove once we migrate to using get_peft_model()
adapter_name = self.model.active_adapter()
self.model.delete_adapter(adapter_name)
self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
Reviewer: qq, does this only load the lora adapter? Maybe we need to add a check to make sure the base model is expected to be frozen/compressed.

Author: Yes, this loads only the best LoRA adapter. The reason I introduced this logic is that our current workflow is slightly incompatible with the HF trainer: we use .add_adapter() instead of get_peft_model(), so the HF trainer doesn't detect the model as a PEFT model. This causes errors in the final load_best_checkpoint() call, so I added this temporary fix until we migrate to get_peft_model().

Currently I execute this logic only if compress is enabled and fsdp2 is not enabled (indicating DDP QLoRA). Do you recommend any other checks?

else:
super()._load_best_model(*args, **kwargs)
Comment on lines +280 to +290
⚠️ Potential issue | 🟠 Major

Add error handling and validate checkpoint existence.

The custom LoRA loading logic lacks safety checks:

  1. No validation that self.state.best_model_checkpoint exists before attempting to load
  2. No error handling if load_adapter fails
  3. The condition checks is_lora and not self.is_fsdp_enabled, but an earlier comment thread mentioned this should only execute "if compress is enabled and fsdp2 is not enabled." The current logic doesn't check the compress flag.

Add defensive checks:

 def _load_best_model(self, *args, **kwargs):
     """Load the best model for final evaluation."""
     is_lora = getattr(self.args, "lora", None)
-    if is_lora and not self.is_fsdp_enabled:
+    is_compressed = getattr(self.quant_args, "compress", False)
+    if is_lora and not self.is_fsdp_enabled and is_compressed:
         # Custom logic for loading best model with LoRA
         # TODO: Remove once we migrate to using get_peft_model()
+        if not self.state.best_model_checkpoint:
+            print_rank_0("No best model checkpoint found, skipping adapter reload")
+            return
+        try:
-            adapter_name = self.model.active_adapter()
-            self.model.delete_adapter(adapter_name)
-            self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+            adapter_name = self.model.active_adapter()
+            self.model.delete_adapter(adapter_name)
+            self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+            print_rank_0(f"Successfully loaded best adapter from {self.state.best_model_checkpoint}")
+        except Exception as e:
+            print_rank_0(f"Failed to load best adapter: {e}")
+            raise
     else:
         super()._load_best_model(*args, **kwargs)


def _patch_accelerate_for_fsdp2_fix(self):
"""Fixes for accelerate prepare.
@@ -337,7 +351,7 @@ def __init__(
if self.quant_cfg is not None and not is_quantized(self.model):
self._quantize_model()
if getattr(self.args, "lora_config", None) is not None:
self.model.add_adapter(self.args.lora_config, adapter_name="adapter")
self.model.add_adapter(self.args.lora_config)
print_rank_0("Lora adapter added.")
self._convert_to_distillation_model()
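
As context for the `get_peft_model` TODOs in this file: the planned migration would wrap the model with PEFT directly instead of calling `add_adapter`, so the HF Trainer recognizes it as a PeftModel. A rough sketch (the base model path and LoraConfig values are placeholders, not part of this PR):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/base-model")  # placeholder checkpoint
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)  # replaces model.add_adapter(lora_config)
model.print_trainable_parameters()
```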

1 change: 0 additions & 1 deletion tests/examples/llm_qat/test_llm_qat.py
@@ -97,7 +97,6 @@ def test_llama_lora_qat_nvfp4(tiny_llama_path, tmp_path):
)


@pytest.mark.skip(reason="Fix QLoRa test failure")
def test_llama_qlora_nvfp4(tiny_llama_path, tmp_path):
_run_command(
[