
Conversation


@sugunav14 sugunav14 commented Sep 22, 2025

What does this PR do?

Type of change: New example

Overview: This PR provides an end-to-end example of fine-tuning a model with QLoRA under DDP and exporting the checkpoint for deployment with vLLM.

  1. This PR contains a temporary fix for loading the best checkpoint at the end of DDP training, which can be removed once we move to using get_peft_model().
  2. The final base checkpoint is exported under output_dir/base_model, while the adapter weights are exported under output_dir.

Usage

Refer to README.md changes

Testing

Trainer

  • ./launch.sh --model meta-llama/Meta-Llama-3-8B --num_epochs 0.01 --lr 1e-3 --do_train True --output_dir test --quant_cfg FP8_DEFAULT_CFG --compress True --lora True

Export

  • python export.py --pyt_ckpt_path test --export_dir test-fp8

Deployment

  • vllm serve test-fp8/base_model --enable-lora --lora-modules sql-lora=test-fp8 --port 8090 --tokenizer test-fp8

  • e2e unit test

  • Sanity check weights, dtypes of generated checkpoint

  • Test phi4
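
Once the vLLM server from the Deployment step above is running, a quick sanity query against its OpenAI-compatible endpoint might look like the sketch below (illustrative only; assumes pip install openai, the --port 8090 used above, and that the adapter is addressed by the name given in --lora-modules, here sql-lora):

# Hypothetical sanity check against the vLLM server started in the Deployment step.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8090/v1", api_key="EMPTY")
resp = client.completions.create(
    model="sql-lora",  # adapter name from --lora-modules
    prompt="Write a SQL query that counts users per country:",
    max_tokens=64,
)
print(resp.choices[0].text)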

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added an export CLI/flow to produce Hugging Face–compatible checkpoints from QLoRA/quantized models, with optional restoration of quantizer state and adapters.
  • Improvements

    • Export now skips non-quantizer params, normalizes per-layer scales (with warnings/clamping for certain formats), removes unused adapters and tied weights, and selectively includes quantized layers for deployment.
    • Trainer saves a calibration snapshot earlier and adds custom best-model loading for LoRA cases.
  • Documentation

    • Updated QLoRA guide with export and vLLM serving instructions.
  • Tests

    • Re-enabled the QLoRA NVFP4 test.


copy-pr-bot bot commented Sep 22, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Sep 22, 2025

Walkthrough

Adds QLoRA-aware export and training-state handling: updates quant export utilities and HF unified export to handle LoRA-specific keys/quantizers, saves/restores ModelOpt calibration state, provides an export script for QLoRA models, updates docs for vLLM deployment, and re-enables a QLoRA NVFP4 test.
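
For orientation, the key handling described above boils down to something like the following sketch: strip the PEFT base_layer prefix from quantized-weight keys and drop LoRA adapter tensors so the exported base checkpoint is deployment-ready. The helper and names are illustrative, not the actual postprocess_state_dict implementation.

from typing import Any

def remap_qlora_keys(state_dict: dict[str, Any]) -> dict[str, Any]:
    # Illustrative sketch of the QLoRA-specific key handling described above.
    replacements = {
        "base_layer.weight": "weight",
        "base_layer.input_scale": "input_scale",
        "base_layer.weight_scale": "weight_scale",
    }
    out = {}
    for key, value in state_dict.items():
        if "lora_A" in key or "lora_B" in key:
            continue  # adapter weights are exported separately, not in the base checkpoint
        for old_suffix, new_suffix in replacements.items():
            if key.endswith(old_suffix):
                key = key[: -len(old_suffix)] + new_suffix
                break
        out[key] = value
    return out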

Changes

Cohort / File(s) Summary
Quant export utilities
modelopt/torch/export/quant_utils.py
postprocess_state_dict signature adds is_modelopt_qlora: bool = False; introduces skip_keys filtering and QLoRA-specific key mappings (e.g., base_layer.*weight/*_scale); two-phase key processing; _amax normalization/clamping per-format with warnings; removal of LoRA adapters when is_modelopt_qlora=True; detection/removal of tied weights; get_quant_config guards against emitting per-layer configs when block_size == 0 for NVFP4-related formats.
HF unified export
modelopt/torch/export/unified_export_hf.py
_export_hf_checkpoint signature updated with is_modelopt_qlora: bool = False; quantized-weight export now requires an enabled weight_quantizer on submodules; forwards is_modelopt_qlora to postprocess_state_dict.
Transformers trainer plugin
modelopt/torch/quantization/plugins/transformers_trainer.py
Standardizes LoRA adapter creation (add_adapter(self.args.lora_config)), saves modelopt_state_calib.pth immediately after calibration (captures quantizer weights) and removes a duplicate save; adds _load_best_model(self, *args, **kwargs) overrides in QATTrainer and QADTrainer to handle LoRA best-model loading when FSDP is not enabled.
QLoRA export script
examples/llm_qat/export.py
New script adding get_lora_model(ckpt_path, device="cuda") and main(args) to load QLoRA checkpoints, restore modelopt calibration state and optional quantizer weights, call _export_hf_checkpoint(..., is_modelopt_qlora=True), and write HF-compatible model artifacts, quant config, adapters, and tokenizer.
Docs update
examples/llm_qat/README.md
Updates deployment guidance: adds export.py usage and vLLM/TRTLLM deployment instructions for both QAT and QLoRA flows; replaces prior note that QLoRA deployment was unavailable.
Test enablement
tests/examples/llm_qat/test_llm_qat.py
Removed @pytest.mark.skip from test_llama_qlora_nvfp4, enabling the test run.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant ExportScript as examples/llm_qat/export.py
  participant Model as QLoRA Model
  participant CalibState as modelopt_state_calib.pth
  participant Exporter as _export_hf_checkpoint
  participant Utils as postprocess_state_dict

  User->>ExportScript: run main(args)
  ExportScript->>Model: load base model + LoRA adapters
  ExportScript->>CalibState: load modelopt_state_calib.pth (if present)
  CalibState-->>Model: restore calibration & optional quantizer weights
  ExportScript->>Exporter: _export_hf_checkpoint(model, is_modelopt_qlora=True)
  Exporter->>Utils: postprocess_state_dict(state_dict, maxbound, quant, is_modelopt_qlora=True)
  Utils-->>Exporter: processed_state_dict + hf_quant_config
  Exporter-->>ExportScript: return artifacts
  ExportScript->>User: write model, adapters, quant config, tokenizer
sequenceDiagram
  autonumber
  participant Trainer as QATTrainer
  participant Model
  participant FS as Filesystem
  participant Loader as _load_best_model

  Trainer->>Model: calibrate quantizers
  Trainer->>FS: save modelopt_state_calib.pth
  alt load best model (no FSDP, LoRA present)
    Trainer->>Loader: _load_best_model(...)
    Loader-->>Trainer: custom best-model load (handle adapters)
  else
    Trainer->>Loader: delegate to superclass
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

I nibble at keys and stitch each scale,
LoRA leaves tucked in a quantized tale.
KV bounds trimmed, tied weights set right,
I hop and I save the calib through the night.
Export packed—hop on, vLLM takes flight! 🥕🐇

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title succinctly captures the primary change by referencing QLoRA export in a DDP context, clearly summarizing the main feature added without extraneous detail. It directly relates to the core functionality introduced—exporting QLoRA models trained with Distributed Data Parallel—and is concise and specific.
Docstring Coverage ✅ Passed Docstring coverage is 91.67% which is sufficient. The required threshold is 80.00%.


@sugunav14 sugunav14 force-pushed the svelury/qlora-ddp-export branch from 035117f to 6254cad on September 22, 2025 at 15:46
@sugunav14 sugunav14 changed the title from "QLoRA DDP export" to "Draft: QLoRA DDP export" on Sep 22, 2025
@sugunav14 sugunav14 changed the title from "Draft: QLoRA DDP export" to "QLoRA DDP export" on Sep 22, 2025
@sugunav14 sugunav14 marked this pull request as draft September 22, 2025 15:49
@sugunav14 sugunav14 self-assigned this Sep 22, 2025

codecov bot commented Sep 22, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.79%. Comparing base (340eb7a) to head (f5f91ab).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #353      +/-   ##
==========================================
- Coverage   73.79%   73.79%   -0.01%     
==========================================
  Files         171      171              
  Lines       17591    17591              
==========================================
- Hits        12982    12981       -1     
- Misses       4609     4610       +1     

☔ View full report in Codecov by Sentry.

@sugunav14 sugunav14 marked this pull request as ready for review September 22, 2025 16:19

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
modelopt/torch/export/quant_utils.py (1)

743-745: Verify dtype casting behavior and potential precision loss.

The dtype casting is applied unconditionally when a dtype is provided, which could lead to precision loss or unexpected behavior if the original weight dtype is more precise than the target dtype.

Consider adding a warning or validation:

 if dtype:
+    if weight.dtype != dtype:
+        logger.info(f"Converting weight from {weight.dtype} to {dtype}")
     weight = weight.to(dtype)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 74061f5 and c3a883c.

📒 Files selected for processing (5)
  • examples/llm_qat/README.md (1 hunks)
  • examples/llm_qat/main.py (1 hunks)
  • modelopt/torch/export/quant_utils.py (5 hunks)
  • modelopt/torch/export/unified_export_hf.py (4 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
modelopt/torch/export/unified_export_hf.py (1)
  • export_hf_checkpoint (503-556)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/export/quant_utils.py (2)
  • maybe_transpose_expert_weight_dimensions (91-120)
  • to_quantized_weight (724-790)
modelopt/torch/export/layer_utils.py (1)
  • is_quantlinear (346-348)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • is_enabled (389-391)
modelopt/torch/export/quant_utils.py (1)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  • to (115-123)
examples/llm_qat/main.py (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
  • export_base_model (291-295)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: linux
🔇 Additional comments (14)
modelopt/torch/export/quant_utils.py (4)

730-730: LGTM! Optional dtype parameter added correctly.

The addition of the optional dtype parameter follows good API design practices with a sensible default of None.


833-836: LGTM! Base layer key mappings added appropriately.

The new base layer mappings support QLoRA export workflows by handling the transformation from LoRA-specific keys to standard model keys.


847-847: LGTM! Proper exclusion of base_layer keys.

The condition correctly filters out base_layer keys from the main processing loop, which is consistent with the new mappings approach.


899-903: LGTM! LoRA adapter cleanup implemented correctly.

The LoRA adapter removal logic is properly implemented to clean up adapter-specific parameters from the exported state dict, ensuring a clean base model export.

examples/llm_qat/main.py (1)

276-278: LGTM! QLoRA export integration implemented correctly.

The conditional call to trainer.export_base_model() is properly guarded by both the LoRA and compression flags, ensuring the base model is only exported when appropriate for QLoRA workflows.

modelopt/torch/export/unified_export_hf.py (5)

88-91: LGTM! Early return for LoRA models prevents unnecessary processing.

The early return correctly skips processing for LoRA-finetuned models by detecting the presence of a base_model attribute, avoiding potential issues with the requantize/resmooth operations.


329-336: Consistent dtype parameter usage.

The non-NVFP4 quantization path correctly passes the dtype parameter to to_quantized_weight, maintaining consistency with the NVFP4 path above.


465-470: Enhanced guard conditions for quantized weight export.

The additional check for hasattr(sub_module, "weight_quantizer") and sub_module.weight_quantizer.is_enabled provides better safety by ensuring quantizers exist and are active before attempting export.


531-536: LGTM! Proper base model export for QLoRA models.

The logic correctly identifies QLoRA models by checking for the base_model attribute and exports the underlying base model instead of the wrapper, which is essential for proper deployment.


317-323: Resolved — internal dtype cast remains; no action needed.
quant_utils.py performs if dtype: weight = weight.to(dtype) (modelopt/torch/export/quant_utils.py lines 743–744), so removing the pre-cast in unified_export_hf.py does not change quantization behavior.

modelopt/torch/quantization/plugins/transformers_trainer.py (3)

31-31: LGTM! Required import added.

The import of export_hf_checkpoint is correctly added to support the new export functionality.


279-290: LoRA-specific best model loading logic implemented correctly.

The implementation correctly handles the difference between LoRA and non-LoRA models. For LoRA models, it properly removes and re-loads the adapter from the best checkpoint path.

Note: The TODO comment indicates this is temporary until get_peft_model() is used, which aligns with the PR description mentioning temporary fixes.


291-296: Simple and effective base model export.

The implementation correctly calls export_hf_checkpoint with the appropriate output directory structure, and the main process check ensures only one process performs the export.

examples/llm_qat/README.md (1)

357-362: Fix vLLM serve example: point --tokenizer to the base/merged model (keep --lora-modules syntax)

  • Location: examples/llm_qat/README.md (lines 357–362).
  • Replace the example so the served model is the merged or original base model and --tokenizer points to the base-model (or HF name). The current --lora-modules adapter=llama3-fp4-qlora is correct.
  • Suggested command (use actual paths/names from your repo):
    vllm serve <merged_or_base_model_path> --enable-lora --lora-modules adapter=<adapter_path> --tokenizer <base_model_path_or_hf_name> --tokenizer-mode auto --trust-remote-code --port 8000
  • Verify tokenizer.pad_token_id == model.config.pad_token_id (or set pad_token = eos_token) to avoid generation/padding issues. Repo check found no base_model directory—confirm exact paths before committing this change.

@sugunav14 sugunav14 requested a review from meenchen September 22, 2025 20:59

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/export/quant_utils.py (1)

732-747: Guard the new dtype pre-cast to prevent fp8/integer pitfalls

Casting to an arbitrary dtype before quantization can raise “Promotion for Float8 Types is not supported” or produce invalid math if an integer/bool dtype is passed. Restrict to safe float dtypes and reject fp8/ints.

Apply this diff:

 def to_quantized_weight(
     weight: torch.Tensor,
     weights_scaling_factor: torch.Tensor,
     quantization: str,
     weights_scaling_factor2: torch.Tensor | None = None,
     block_size: int | None = None,
     dtype: torch.dtype | None = None,
 ):
@@
-    if dtype:
-        weight = weight.to(dtype)
+    if dtype is not None:
+        # Only allow >=16-bit float dtypes here; fp8 and non-floats break downstream ops.
+        allowed_dtypes = {torch.float32, torch.float16, torch.bfloat16}
+        disallowed = {getattr(torch, "float8_e4m3fn", None)}
+        if dtype in disallowed or dtype not in allowed_dtypes:
+            raise ValueError(f"Unsupported pre-quant cast dtype: {dtype}. Use float32/float16/bfloat16.")
+        weight = weight.to(dtype)
🧹 Nitpick comments (1)
modelopt/torch/export/quant_utils.py (1)

611-612: Replace commented-out debug print with logger.debug or remove

Keep logs consistent and avoid dead code.

Apply this diff:

-        # print(f"DEBUG LOG: Processing layer {k} with quantization {v}, block size {block_size_value}")
+        logger.debug("Processing layer %s with quantization=%s, block_size=%s", k, v, block_size_value)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8b97a10 and 7b7188e.

📒 Files selected for processing (1)
  • modelopt/torch/export/quant_utils.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/export/quant_utils.py (1)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  • to (115-123)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (1)
modelopt/torch/export/quant_utils.py (1)

835-853: Base-layer key remap: avoid silent drops and verify collisions

Unmapped base_layer.* keys are dropped by design. If other base attributes exist (e.g., bias), they’ll vanish silently, and remaps may overwrite existing top-level keys.

Please confirm:

  • All required base_layer.* fields are covered by replacements.
  • Remapped targets (e.g., “weight”, “input_scale”, “weight_scale”) won’t collide with already-present keys.

If needed, I can add a warning when a base_layer key is encountered but not remapped. Want a patch?
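
For illustration, such a warning might look like the sketch below (hypothetical helper; replacements is the suffix-mapping dict discussed above):

import warnings

def warn_on_unmapped_base_layer_key(key: str, replacements: dict[str, str]) -> None:
    # Surface base_layer.* keys that match none of the replacement suffixes
    # instead of silently dropping them from the exported state dict.
    if "base_layer" in key and not any(key.endswith(suffix) for suffix in replacements):
        warnings.warn(f"Dropping unmapped base_layer key during QLoRA export: {key}")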

@sugunav14 sugunav14 marked this pull request as draft September 23, 2025 02:43
Contributor

@realAsma realAsma left a comment


Overall implementation looks good. However, we don't have to combine the export for QLoRA with transformers_trainer. We should do the export via hf_ptq.py

@sugunav14
Author

Overall implementation looks good. However, we don't have to combine the export for QLoRA with transformers_trainer. We should do the export via hf_ptq.py

Okay, I will try refactoring the PR to do that! Thank you!

@realAsma
Contributor

@sugunav14 sounds good; for example, here is how we support regular QAT deployment - #353 (comment)

I am thinking we should have something like:

./scripts/huggingface_example.sh --model $BASE_MODEL_PATH --adapter-path $ADAPTER_PATH --quant w4a8_awq

For deployment of QAT and QLoRA checkpoints, we still need to specify quant. This seems cumbersome for users. If the model is already quantized, as in QAT/QLoRA checkpoints, we already skip the quantization stage here:

model_is_already_quantized = is_quantized(model)

For QAT/QLoRA, can we support the following usage:

./scripts/huggingface_example.sh --model $BASE_MODEL_PATH --adapter-path $ADAPTER_PATH

Collaborator

@cjluo-nv cjluo-nv left a comment


Could you test the Phi-4-multimodal-instruct export (FP8 and NVFP4) and make sure the safetensors are the same before and after the change?

To quantize Phi-4-multimodal-instruct, you need to:

  1. Download https://huggingface.co/microsoft/Phi-4-multimodal-instruct
  2. modify https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py#L2101 and enforce using InputMode.LANGUAGE.
  3. use transformers 4.48
  4. Run FP8 and NVFP4 PTQ
  5. Compare the generated safetensors with https://huggingface.co/nvidia/Phi-4-multimodal-instruct-FP4 and https://huggingface.co/nvidia/Phi-4-multimodal-instruct-FP8. Make sure the tensor keys are the same (a comparison sketch follows below).
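
As a rough illustration of step 5, the key comparison could look like the sketch below (assumes single-file model.safetensors checkpoints; sharded checkpoints would need to walk the index file):

from safetensors import safe_open

def compare_tensor_keys(path_a: str, path_b: str) -> None:
    # Print tensor names that exist in only one of the two checkpoints.
    with safe_open(path_a, framework="pt") as fa, safe_open(path_b, framework="pt") as fb:
        keys_a, keys_b = set(fa.keys()), set(fb.keys())
    print("only in A:", sorted(keys_a - keys_b))
    print("only in B:", sorted(keys_b - keys_a))

compare_tensor_keys("before/model.safetensors", "after/model.safetensors")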

# Get the corresponding AWQ block size
block_size_value = layer_config_dict.get(awq_key, 0)

# print(f"DEBUG LOG: Processing layer {k} with quantization {v}, block size {block_size_value}")
Collaborator


remove

    if isinstance(weight, QTensorWrapper):
        return weight.data

    if dtype:
Collaborator


do we have a case where we need to cast the weights?

Author


Yea, seems like in the current unified export logic we do perform a cast before quantizing the weights

keys_to_delete.append(key)

# remove LoRA adapters from state dict
for key, value in post_state_dict.items():
Collaborator


what if the original model has lora adapters like phi4-multimodal?

Author


I did not think of that case! Thanks for the catch!

layer_config_dict[name + ".quantization"] = quantization_format
layer_config_dict[name + ".awq_block_size"] = block_size
# Handles case if default weight quantizer is not enabled or is None
if block_size != 0:
Collaborator


will this impact per tensor quant like fp8?

Author


It will, just updated the condition.

@sugunav14 sugunav14 marked this pull request as ready for review September 30, 2025 16:02

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (9)
examples/llm_qat/export.py (7)

32-56: Add error handling for missing modelopt state file.

Line 47 loads modelopt_state_calibration.pth without checking if it exists. If the file is missing (e.g., model was not trained with QLoRA or was trained with an older version), torch.load will raise a FileNotFoundError with an unclear error message.

Consider adding explicit validation:

+from pathlib import Path
+
 def get_lora_model(
     ckpt_path: str,
     device="cuda",
 ):
     """
     Loads a QLoRA model that has been trained using modelopt trainer.
     """
+    # Validate modelopt state file exists
+    modelopt_state_path = Path(ckpt_path) / "modelopt_state_calibration.pth"
+    if not modelopt_state_path.exists():
+        raise FileNotFoundError(
+            f"modelopt_state_calibration.pth not found in {ckpt_path}. "
+            "Ensure the model was trained with QLoRA using the modelopt trainer."
+        )
+
     device_map = "auto"
     if device == "cpu":
         device_map = "cpu"
 
     # Load model with adapters
     model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)
 
     # Restore modelopt state
-    modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calibration.pth", weights_only=False)
+    modelopt_state = torch.load(modelopt_state_path, weights_only=False)
     restore_from_modelopt_state(model, modelopt_state)

53-53: Consider using print_rank_0 for distributed consistency.

The print statement on line 53 will execute on all ranks in a distributed setting, potentially causing log clutter. If this script may be used in a distributed context, consider using print_rank_0 from modelopt.torch.utils.logging (already used in transformers_trainer.py).


87-87: Remove or conditionalize debug print of config data.

Line 87 prints the entire config_data dictionary, which can be very verbose. This appears to be debug code left in the script.

Consider removing or making it conditional:

-        print(config_data)
+        # Optionally log config for debugging
+        # print(config_data)

98-102: Improve error message clarity.

The error message "Cannot export model to the model_config" is unclear. Consider clarifying what failed and providing actionable guidance.

         warnings.warn(
-            "Cannot export model to the model_config. The modelopt-optimized model state_dict"
-            " can be saved with torch.save for further inspection."
+            f"Failed to export model to {export_dir}. The modelopt-optimized model state_dict "
+            "can be saved with torch.save for further inspection."
         )

29-29: Remove unused RAND_SEED constant.

The RAND_SEED constant is defined but never used in the script.

-RAND_SEED = 1234
-

32-35: Add type hints to function signatures.

The function lacks type hints, which would improve code clarity and enable static type checking. Consider adding return type annotation.

-def get_lora_model(
-    ckpt_path: str,
-    device="cuda",
-):
+def get_lora_model(
+    ckpt_path: str,
+    device: str = "cuda",
+) -> torch.nn.Module:

59-59: Add type hints to main function.

Consider adding type hints for better code clarity.

-def main(args):
+def main(args: argparse.Namespace) -> None:
modelopt/torch/export/quant_utils.py (2)

903-907: Improve LoRA key detection to avoid false positives.

The substring check "lora" in key on line 906 may inadvertently match keys that contain "lora" as part of a larger word (e.g., "flora", "explorer") or match model-native LoRA adapters that should be preserved.

Based on learnings

Apply this diff to use more precise path-segment matching:

     # remove LoRA adapters from state dict
     if is_modelopt_qlora:
-        for key in post_state_dict:
-            if "lora" in key and key not in keys_to_delete:
+        for key in list(post_state_dict.keys()):
+            parts = key.split(".")
+            # Check if "lora" appears as a complete path segment or as a prefix (e.g., lora_A, lora_B)
+            if (("lora" in parts or any(p.startswith("lora_") for p in parts)) 
+                and key not in keys_to_delete):
                 keys_to_delete.append(key)

Note: Also changed to iterate over list(post_state_dict.keys()) to avoid issues with dictionary modification during iteration.


1086-1096: Improvement: NVFP4-specific block_size handling partially addresses past concerns.

The new logic correctly skips layer config entries for NVFP4 formats when block_size == 0, which indicates the weight_quantizer is not enabled. This is an improvement over the previous approach that skipped ALL formats with block_size == 0.

However, awq_block_size is still written for all formats on line 1096, even those that don't use block quantization (e.g., QUANTIZATION_INT8_SQ, QUANTIZATION_FP8, QUANTIZATION_FP8_PC_PT).

Consider only writing awq_block_size for formats that actually use it:

             # Construct per layer config dictionary
             layer_config_dict[name + ".quantization"] = quantization_format
-            layer_config_dict[name + ".awq_block_size"] = block_size
+            # Only write block_size for block-quantized formats
+            if block_size > 0:
+                layer_config_dict[name + ".awq_block_size"] = block_size

This avoids polluting the config with unnecessary zero values for per-tensor formats, though the process_layer_quant_config function (line 601) already filters these out with if "awq_block_size" in k: continue.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7b7188e and 1959178.

📒 Files selected for processing (5)
  • examples/llm_qat/README.md (1 hunks)
  • examples/llm_qat/export.py (1 hunks)
  • modelopt/torch/export/quant_utils.py (4 hunks)
  • modelopt/torch/export/unified_export_hf.py (3 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/export/unified_export_hf.py
🧰 Additional context used
🧬 Code graph analysis (3)
examples/llm_qat/export.py (4)
modelopt/torch/export/convert_hf_config.py (1)
  • convert_hf_quant_config_format (21-117)
modelopt/torch/export/unified_export_hf.py (1)
  • _export_hf_checkpoint (336-495)
modelopt/torch/opt/conversion.py (2)
  • restore_from_modelopt_state (510-567)
  • modelopt_state (444-486)
modelopt/torch/quantization/utils.py (1)
  • set_quantizer_state_dict (459-466)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)
modelopt/torch/opt/conversion.py (2)
  • modelopt_state (444-486)
  • save (489-507)
modelopt/torch/quantization/utils.py (1)
  • get_quantizer_state_dict (446-456)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/export/quant_utils.py (2)
modelopt/torch/export/unified_export_megatron.py (1)
  • state_dict (465-469)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • maxbound (188-194)
🔇 Additional comments (5)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)

212-219: LGTM! Pre-compression state capture enables correct export.

Saving the modelopt state and quantizer weights before compression is essential for the export workflow. The export.py script (line 47) loads this exact file to restore the quantizer configuration needed for proper model export. The timing (post-calibration, pre-compression) ensures the quantization metadata is preserved for deployment.


287-297: Ignore attribute name inconsistency suggestion. args.lora is the boolean toggle and args.lora_config holds the adapter config; both are used correctly.

Likely an incorrect or invalid review comment.


149-149: Verify default adapter naming in add_adapter calls.

  • Confirm that self.model.add_adapter(self.args.lora_config) (lines 149 & 361) assigns a default adapter_name that self.model.active_adapter() (line 293) returns, so delete_adapter and load_adapter operate correctly.
  • If not, pass an explicit adapter_name to add_adapter to guarantee consistent lifecycle management.
examples/llm_qat/export.py (1)

71-76: Clarify which quantization config format is saved to hf_quant_config.json.

Line 73-74 saves the original hf_quant_config (modelopt format), while line 76 converts it to llm-compressor format and embeds it in config.json (line 89). This means two different formats are persisted:

  • base_model/hf_quant_config.json: modelopt format
  • base_model/config.json (quantization_config field): llm-compressor format

If both formats are needed for different consumers, consider adding a comment explaining why both are saved. Otherwise, consider saving only the converted format.

modelopt/torch/export/quant_utils.py (1)

835-846: Potential logic issue: skip_keys may prevent replacements from being applied.

On line 846, "base_layer" is appended to skip_keys. Then on line 852, the code skips any key where all(skip_key not in key for skip_key in skip_keys) is true. This means keys containing "base_layer" will be skipped entirely and never reach the replacement logic (lines 857-882).

However, the replacements dictionary (lines 839-845) includes patterns like "base_layer.weight" which should be transformed to "weight". These replacements won't be applied because the keys are filtered out first.

Consider refactoring to apply replacements first, then skip keys that should be removed:

-    skip_keys = ["output_quantizer", "_amax", "_bias_value", "input_quantizer._pre_quant_scale"]
+    # Keys to skip entirely (not related to quantizers or base transformations)
+    skip_patterns = ["output_quantizer"]
 
     # For modelopt-trained LoRA models, we need to remove the base_layer prefix from the keys for deployment
     if is_modelopt_qlora:
         replacements.update(
             {
                 "base_layer.weight": "weight",
                 "base_layer.input_scale": "input_scale",
                 "base_layer.weight_scale": "weight_scale",
             }
         )
-        skip_keys.append("base_layer")
 
     post_state_dict = {}
 
     for key, value in state_dict.items():
-        # Skip keys not related to quantizers
-        if all(skip_key not in key for skip_key in skip_keys):
-            post_state_dict[key] = value
-            continue
+        # Skip keys that should be entirely filtered out
+        if any(skip_pattern in key for skip_pattern in skip_patterns):
+            continue
 
-        # Apply replacements if the key matches any suffix in the replacements dict
+        # Try to apply replacements first
+        replaced = False
         for old_suffix, new_suffix in replacements.items():
             if key.endswith(old_suffix):
                 prefix = key[: -len(old_suffix)]
 
                 if "_amax" in key:
                     # ... existing _amax handling ...
 
                 post_state_dict[prefix + new_suffix] = value
+                replaced = True
                 break
+        
+        # If no replacement was applied and key doesn't contain quantizer-specific suffixes, keep it
+        if not replaced:
+            quantizer_suffixes = ["_amax", "_bias_value", "input_quantizer._pre_quant_scale"]
+            if not any(key.endswith(suffix) for suffix in quantizer_suffixes):
+                post_state_dict[key] = value

Likely an incorrect or invalid review comment.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
examples/llm_qat/export.py (4)

29-29: Remove unused constant.

RAND_SEED is defined but never used in the script.

-RAND_SEED = 1234
-

32-56: Consider improving robustness and observability.

The function lacks error handling for missing checkpoint files and uses print instead of logging. Consider:

  1. Error handling: Wrap file loading in try-except to provide clear error messages if modelopt_state_calibration.pth is missing or corrupted.
  2. Logging: Replace print("Restoring modelopt weights") with proper logging (e.g., logging.info(...)).
  3. Documentation: Expand the docstring to document parameters, return value, and expected checkpoint structure.

Example improvements:

+import logging
+
 def get_lora_model(
     ckpt_path: str,
     device="cuda",
 ):
     """
-    Loads a QLoRA model that has been trained using modelopt trainer.
+    Loads a QLoRA model that has been trained using modelopt trainer.
+    
+    Args:
+        ckpt_path: Path to the checkpoint directory containing the model and modelopt state.
+        device: Device to load the model on ("cuda" or "cpu").
+    
+    Returns:
+        The loaded model with restored modelopt and quantizer state.
     """
     device_map = "auto"
     if device == "cpu":
         device_map = "cpu"
 
     # Load model with adapters
     model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)
 
     # Restore modelopt state
+    try:
-        modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calibration.pth", weights_only=False)
+            modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calibration.pth", weights_only=False)
+    except FileNotFoundError as e:
+        raise FileNotFoundError(
+            f"modelopt_state_calibration.pth not found in {ckpt_path}. "
+            "Ensure the checkpoint was saved correctly during training."
+        ) from e
     restore_from_modelopt_state(model, modelopt_state)
 
     # Restore modelopt quantizer state dict
     modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
     if modelopt_weights is not None:
-        print("Restoring modelopt weights")
+        logging.info("Restoring modelopt quantizer weights")
         set_quantizer_state_dict(model, modelopt_weights)
 
     return model

59-101: Use Path objects consistently and improve error messages.

The function mixes Path objects with f-string concatenation. For consistency and robustness, use Path objects throughout. Also, the exception handler's warning message is generic and doesn't help users diagnose the issue.

Apply this diff:

 def main(args):
     # Load model
     model = get_lora_model(args.pyt_ckpt_path, args.device)
     tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
 
     # Export HF checkpoint
     export_dir = Path(args.export_path)
     export_dir.mkdir(parents=True, exist_ok=True)
     base_model_dir = export_dir / "base_model"
     base_model_dir.mkdir(parents=True, exist_ok=True)
 
     try:
         post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=True)
 
-        with open(f"{export_dir}/base_model/hf_quant_config.json", "w") as file:
+        with open(base_model_dir / "hf_quant_config.json", "w") as file:
             json.dump(hf_quant_config, file, indent=4)
 
         hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
 
         # Save base model
-        model.base_model.save_pretrained(f"{export_dir}/base_model", state_dict=post_state_dict)
+        model.base_model.save_pretrained(base_model_dir, state_dict=post_state_dict)
         # Save adapters
         model.save_pretrained(export_dir)
 
-        config_path = f"{export_dir}/base_model/config.json"
+        config_path = base_model_dir / "config.json"
 
         # In the case of LoRA model.save_pretrained does not save the correct config.json
         config_data = model.config.to_dict()
 
         config_data["quantization_config"] = hf_quant_config
 
         with open(config_path, "w") as file:
             json.dump(config_data, file, indent=4)
 
         # Save tokenizer
         tokenizer.save_pretrained(export_dir)
 
     except Exception as e:
         warnings.warn(
-            "Cannot export model to the model_config. The modelopt-optimized model state_dict"
-            " can be saved with torch.save for further inspection."
+            f"Export failed: {e}. The modelopt-optimized model state_dict "
+            "can be saved with torch.save for further inspection."
         )
         raise e

85-86: Clarify the comment about config.json handling.

The comment is grammatically incomplete. Consider rephrasing for clarity.

-        # In the case of LoRA model.save_pretrained does not save the correct config.json
+        # For LoRA models, save_pretrained does not include quantization_config in config.json,
+        # so we manually reconstruct and write it here.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1959178 and 0784fe8.

📒 Files selected for processing (1)
  • examples/llm_qat/export.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/llm_qat/export.py (4)
modelopt/torch/export/convert_hf_config.py (1)
  • convert_hf_quant_config_format (21-117)
modelopt/torch/export/unified_export_hf.py (1)
  • _export_hf_checkpoint (336-495)
modelopt/torch/opt/conversion.py (2)
  • restore_from_modelopt_state (510-567)
  • modelopt_state (444-486)
modelopt/torch/quantization/utils.py (1)
  • set_quantizer_state_dict (459-466)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (2)
examples/llm_qat/export.py (2)

104-118: LGTM!

The argument parser is well-structured with appropriate defaults and a required checkpoint path.


1-118: Acknowledge planned refactor per PR discussion.

Based on the PR objectives, this script is temporary. The author has agreed to refactor the export logic to integrate with hf_ptq.py per reviewer feedback. This standalone script serves the immediate QLoRA export use case but should be consolidated with the existing export flow in a follow-up.

Based on PR objectives summary.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
examples/llm_qat/README.md (1)

368-370: Consider clarifying the adapter name convention.

The vLLM command uses --lora-modules adapter=llama3-fp4-qlora-hf where "adapter" is a generic name. It would be helpful to clarify whether this name is:

  • A fixed convention expected by vLLM
  • User-configurable and should match a specific naming pattern
  • Related to the export output structure

Adding a brief note about the adapter naming would improve usability, especially given the PR discussion about simplifying the CLI for QLoRA checkpoints.

modelopt/torch/export/quant_utils.py (1)

904-908: Substring matching may catch unintended keys.

The LoRA adapter removal logic uses substring matching ("lora" in key), which could inadvertently match keys like "flora", "exploration", or "coloration". While the is_modelopt_qlora gate reduces this risk, consider more precise matching.

Apply this diff for more precise LoRA key detection:

     # remove LoRA adapters from state dict
     if is_modelopt_qlora:
         for key in post_state_dict:
-            if "lora" in key and key not in keys_to_delete:
+            # Match LoRA keys more precisely: check for .lora. or lora_ patterns
+            parts = key.split(".")
+            if (any(p.startswith("lora_") or p == "lora" for p in parts)) and key not in keys_to_delete:
                 keys_to_delete.append(key)

This ensures matching only keys with "lora" as a complete path component or prefix (e.g., "model.lora_A.weight", "adapter.lora.bias"), not arbitrary substrings.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0784fe8 and dc7ee10.

📒 Files selected for processing (3)
  • examples/llm_qat/README.md (1 hunks)
  • examples/llm_qat/export.py (1 hunks)
  • modelopt/torch/export/quant_utils.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/llm_qat/export.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/export/quant_utils.py (4)
modelopt/torch/export/unified_export_megatron.py (1)
  • state_dict (465-469)
modelopt/torch/opt/conversion.py (1)
  • state_dict (130-132)
modelopt/torch/distill/distillation_model.py (1)
  • state_dict (189-192)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • maxbound (188-194)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (6)
examples/llm_qat/README.md (1)

357-370: Documentation provides clear QLoRA deployment flow.

The new export and deployment instructions successfully replace the previous experimental note with actionable guidance. The command examples are well-structured and the reference to vLLM documentation is helpful.

modelopt/torch/export/quant_utils.py (5)

812-827: LGTM: Parameter documented.

The new is_modelopt_qlora parameter has been properly documented in the docstring at line 824. The documentation is clear and consistent with the existing style.


836-836: LGTM: Centralized skip logic.

The skip_keys list provides a cleaner, more maintainable approach to filtering non-quantizer keys. This addresses the concern from previous reviews about structured key handling.


838-847: LGTM: LoRA handling properly gated.

The LoRA-specific key transformations are now correctly gated behind the is_modelopt_qlora flag. This addresses the previous concern about unconditional LoRA adapter removal and ensures the logic only applies to ModelOpt QLoRA exports.


1087-1093: LGTM: NVFP4 block_size filtering fixed.

The guard correctly skips only NVFP4-related formats when block_size == 0, indicating the weight quantizer is not enabled. This fixes the previous issue where non-block formats (INT8_SQ, FP8, FP8_PC_PT) were incorrectly dropped. Non-block formats now proceed to lines 1095-1097 regardless of block_size.


851-855: postprocess_state_dict filtering logic is correct
Non-quantizer keys are retained, and all quantizer/base_layer keys are either transformed per the replacements mapping or removed as intended.

block_size = get_weight_block_size(module)

# In the case of NVFP4, block_size 0 indicates weight_quantizer is not enabled
if block_size == 0 and quantization_format in [
Collaborator


do we have a better flag instead of checking the block_size? E.g. weight_quantizer enabled vs disabled?

Author


let me check and update the PR!

Contributor


+1
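
For reference, the alternative being suggested could look roughly like this sketch (hypothetical helper name; the real guard would live where block_size is currently inspected in get_quant_config):

def has_enabled_weight_quantizer(module) -> bool:
    # Rely on the quantizer state itself rather than using block_size == 0
    # as a proxy for "weight quantizer not enabled".
    weight_quantizer = getattr(module, "weight_quantizer", None)
    return weight_quantizer is not None and weight_quantizer.is_enabled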

Comment on lines 212 to 220
# Save modelopt state before compression. This is used to later export the model for deployment.
modelopt_state = mto.modelopt_state(self.model)
modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(self.model)
torch.save(modelopt_state, f"{self.args.output_dir}/modelopt_state_calib.pth")

print_rank_0(
f"Saved modelopt state before compression to {f'{self.args.output_dir}/modelopt_state_calib.pth'}"
)


Should we move this below and save modelopt state only before compression?

# TODO: Remove once we migrate to using get_peft_model()
adapter_name = self.model.active_adapter()
self.model.delete_adapter(adapter_name)
self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)

qq, does this only load the lora adapter? Maybe we need to add a check to make sure the base model is expected to be frozen/compressed.

Author


Yes, this loads only the best LoRA adapter. I introduced this logic because our current workflow is slightly incompatible with the HF trainer: we use .add_adapter() instead of get_peft_model(), so the HF trainer does not detect the model as a PEFT model. This causes errors in the final load_best_checkpoint() call, so I added this temporary fix until we migrate to get_peft_model().

Currently I execute this logic only if compress is enabled and fsdp2 is not enabled (indicating DDP QLoRA). Do you recommend any other checks?
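
A minimal sketch of the extra safeguard being discussed, i.e. asserting the base model is frozen before swapping adapters (hypothetical helper; adjust the name filter to match the actual adapter parameter names):

import torch.nn as nn

def assert_base_model_frozen(model: nn.Module) -> None:
    # Raise if any non-LoRA parameter is still trainable before the adapter reload.
    trainable = [
        name for name, param in model.named_parameters()
        if "lora_" not in name and param.requires_grad
    ]
    if trainable:
        raise RuntimeError(f"Expected a frozen/compressed base model; trainable params: {trainable[:5]}")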

Comment on lines 74 to 96
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=True)

with open(f"{base_model_dir}/hf_quant_config.json", "w") as file:
json.dump(hf_quant_config, file, indent=4)

hf_quant_config = convert_hf_quant_config_format(hf_quant_config)

# Save base model
model.base_model.save_pretrained(f"{base_model_dir}", state_dict=post_state_dict)
# Save adapters
model.save_pretrained(export_dir)

config_path = f"{base_model_dir}/config.json"

config_data = model.config.to_dict()

config_data["quantization_config"] = hf_quant_config

with open(config_path, "w") as file:
json.dump(config_data, file, indent=4)

# Save tokenizer
tokenizer.save_pretrained(export_dir)
Contributor


Does it make more sense to add this to utils as a function (or add the QLoRA-specific logic as part of export_hf_checkpoint so it can be used by other scripts easily) that can be used to export at the end of training? The main file could provide an optional export flag instead of creating a new export.py.

Author


I could move the saving logic to export_hf_checkpoint() and simplify the script. The reason I have a separate script is so that we don't tie the export to the main training script. This could be useful in the case where the user has a couple of saved checkpoints and later wants to try exporting with different deployment options (merged adapters in bf16, merged adapters with quantization or quantized base model with bf16 adapters).

I also wanted to avoid going through hf_ptq.py because it would add more complexity to our existing hf_ptq.py script: the user would have to specify an additional lora flag, we would potentially need more flags to cover the available QLoRA deployment options, and the user would either have to keep track of and specify the base model's qformat or we would need additional logic to save and infer that information.

Collaborator


@kinjalpatel27 do you see another example that could benefit from this export script as well?

Contributor


I agree it makes sense to avoid using hf_ptq.py for this export.
I think simplifying the script and exposing it as a function might be better if a user wants to write their own script to train and later export for deployment. I don't have another example script in mind at the moment.
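
Sketch of the refactor being discussed: packaging the export steps from the snippet above as a reusable helper. The function name is illustrative, not an existing ModelOpt API.

import json
from pathlib import Path

from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint

def export_qlora_checkpoint(model, tokenizer, export_dir: str) -> None:
    # Illustrative wrapper mirroring the steps in the snippet above.
    export_path = Path(export_dir)
    base_model_dir = export_path / "base_model"
    base_model_dir.mkdir(parents=True, exist_ok=True)

    post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=True)
    (base_model_dir / "hf_quant_config.json").write_text(json.dumps(hf_quant_config, indent=4))

    model.base_model.save_pretrained(base_model_dir, state_dict=post_state_dict)  # quantized base model
    model.save_pretrained(export_path)  # LoRA adapters
    tokenizer.save_pretrained(export_path)

    # LoRA save_pretrained does not write the full config.json, so rebuild it here.
    config = model.config.to_dict()
    config["quantization_config"] = convert_hf_quant_config_format(hf_quant_config)
    (base_model_dir / "config.json").write_text(json.dumps(config, indent=4))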

Comment on lines 212 to 220
# Save modelopt state before compression. This is used to later export the model for deployment.
modelopt_state = mto.modelopt_state(self.model)
modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(self.model)
torch.save(modelopt_state, f"{self.args.output_dir}/modelopt_state_calib.pth")

print_rank_0(
f"Saved modelopt state before compression to {f'{self.args.output_dir}/modelopt_state_calib.pth'}"
)

Contributor

@realAsma realAsma Oct 2, 2025


Can we use self._save_modelopt_state_with_weights() instead of this?

This save here does not correctly handle distributed training. We should not save from all ranks.

Author


Thanks for the catch! Will update.
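
For context, a rank-0-only save (which _save_modelopt_state_with_weights() presumably already handles) looks roughly like this sketch:

import torch
import torch.distributed as dist

def save_on_rank_0(state: dict, path: str) -> None:
    # Save from a single rank and keep the other ranks in sync (illustrative only).
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(state, path)
    if dist.is_initialized():
        dist.barrier()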

self.model.add_adapter(self.args.lora_config)
print_rank_0("Lora adapter added.")

if hasattr(self.model, "peft_config") and self.quant_cfg is not None:
Contributor


Why do we need disable_lora_quantizers_in_config? This does not seem warranted to me.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
examples/llm_qat/export.py (2)

71-71: Make QLoRA detection more explicit.

The condition is_qlora = hasattr(model, "peft_config") assumes any model with PEFT config is a QLoRA model. This may incorrectly identify regular LoRA (without quantization) as QLoRA. Consider checking for both PEFT config and quantization state:

-    is_qlora = hasattr(model, "peft_config")
+    from modelopt.torch.quantization.utils import is_quantized
+    is_qlora = hasattr(model, "peft_config") and is_quantized(model)

109-114: Make exception handling more specific and informative.

The broad except Exception catches all exceptions with a generic warning message. This can hide specific failures (I/O errors, permission issues, serialization errors) and makes debugging harder.

Consider handling specific exception types:

-    except Exception as e:
-        warnings.warn(
-            "Cannot export model to the model_config. The modelopt-optimized model state_dict"
-            " can be saved with torch.save for further inspection."
-        )
-        raise e
+    except (OSError, IOError) as e:
+        warnings.warn(f"Failed to write export files to {base_model_dir}: {e}")
+        raise
+    except Exception as e:
+        warnings.warn(
+            f"Export failed: {e}. The modelopt-optimized model state_dict "
+            "can be saved with torch.save for further inspection."
+        )
+        raise
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dc7ee10 and 7e614ad.

📒 Files selected for processing (2)
  • examples/llm_qat/export.py (1 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/llm_qat/export.py (6)
modelopt/torch/export/convert_hf_config.py (1)
  • convert_hf_quant_config_format (21-117)
modelopt/torch/export/unified_export_hf.py (1)
  • _export_hf_checkpoint (336-495)
modelopt/torch/opt/conversion.py (2)
  • restore_from_modelopt_state (510-567)
  • modelopt_state (444-486)
modelopt/torch/quantization/utils.py (1)
  • set_quantizer_state_dict (459-466)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/opt/plugins/huggingface.py (1)
  • enable_huggingface_checkpointing (127-162)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (2)
modelopt/torch/quantization/plugins/transformers_trainer.py (2)

212-213: Confirm state save timing aligns with export requirements.

The ModelOpt state is now saved immediately after quantization and before compression (line 217). A previous reviewer (meenchen) questioned whether saving should occur only before compression. Please confirm this placement correctly captures the calibrated quantizer state needed for the export workflow, especially considering that compression modifies weights.


149-149: Adapter_name removal is safe: searches found no hardcoded “adapter” strings or API calls expecting a literal adapter name.

Comment on lines +53 to +62
if hasattr(model, "peft_config"):
modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
restore_from_modelopt_state(model, modelopt_state)
print_rank_0("Restored modelopt state")

# Restore modelopt quantizer state dict
modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
if modelopt_weights is not None:
set_quantizer_state_dict(model, modelopt_weights)
print_rank_0("Restored modelopt quantizer state dict")

⚠️ Potential issue | 🟠 Major

Verify the state restoration order matches the trainer's pattern.

The sequence here differs from the trainer's _restore_modelopt_state_with_weights method (transformers_trainer.py lines 184-190):

Trainer pattern:

  1. Load state dict
  2. Pop modelopt_state_weights
  3. Call restore_from_modelopt_state
  4. Set quantizer state dict

Current export.py pattern:

  1. Load state dict
  2. Call restore_from_modelopt_state (line 55)
  3. Pop modelopt_state_weights (line 59)
  4. Set quantizer state dict

The pop operation at line 59 occurs after restore_from_modelopt_state, which may cause the method to receive and process the weights that should be handled separately. Align the order with the trainer's proven pattern.

Apply this diff:

     # Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
     if hasattr(model, "peft_config"):
         modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
+        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
         restore_from_modelopt_state(model, modelopt_state)
         print_rank_0("Restored modelopt state")
 
         # Restore modelopt quantizer state dict
-        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
         if modelopt_weights is not None:
             set_quantizer_state_dict(model, modelopt_weights)
             print_rank_0("Restored modelopt quantizer state dict")
🤖 Prompt for AI Agents
In examples/llm_qat/export.py around lines 53 to 62, the restore order differs
from the trainer: pop modelopt_state_weights before calling
restore_from_modelopt_state so the weights are removed from the state passed
into restore; specifically, after loading modelopt_state, call
modelopt_state.pop("modelopt_state_weights", None) and keep the popped value in
modelopt_weights, then call restore_from_modelopt_state(model, modelopt_state),
and finally if modelopt_weights is not None call set_quantizer_state_dict(model,
modelopt_weights) and print the restored messages.
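
Pulled together, the corrected sequence reads as follows. This is a minimal sketch, not the reviewed file: the import paths follow the code graph listed above, the checkpoint filename comes from the reviewed script, and the helper name restore_modelopt_for_lora is made up for illustration.

import torch

from modelopt.torch.opt.conversion import restore_from_modelopt_state
from modelopt.torch.quantization.utils import set_quantizer_state_dict


def restore_modelopt_for_lora(model, ckpt_path: str):
    state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
    # Pop the quantizer weights first so restore_from_modelopt_state never sees them.
    weights = state.pop("modelopt_state_weights", None)
    restore_from_modelopt_state(model, state)
    if weights is not None:
        set_quantizer_state_dict(model, weights)
    return model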

Comment on lines +280 to +290
def _load_best_model(self, *args, **kwargs):
    """Load the best model for final evaluation."""
    is_lora = getattr(self.args, "lora", None)
    if is_lora and not self.is_fsdp_enabled:
        # Custom logic for loading best model with LoRA
        # TODO: Remove once we migrate to using get_peft_model()
        adapter_name = self.model.active_adapter()
        self.model.delete_adapter(adapter_name)
        self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
    else:
        super()._load_best_model(*args, **kwargs)

⚠️ Potential issue | 🟠 Major

Add error handling and validate checkpoint existence.

The custom LoRA loading logic lacks safety checks:

  1. No validation that self.state.best_model_checkpoint exists before attempting to load
  2. No error handling if load_adapter fails
  3. The condition checks is_lora and not self.is_fsdp_enabled, but an earlier comment thread mentioned this should only execute "if compress is enabled and fsdp2 is not enabled." The current logic doesn't check the compress flag.

Add defensive checks:

 def _load_best_model(self, *args, **kwargs):
     """Load the best model for final evaluation."""
     is_lora = getattr(self.args, "lora", None)
-    if is_lora and not self.is_fsdp_enabled:
+    is_compressed = getattr(self.quant_args, "compress", False)
+    if is_lora and not self.is_fsdp_enabled and is_compressed:
         # Custom logic for loading best model with LoRA
         # TODO: Remove once we migrate to using get_peft_model()
+        if not self.state.best_model_checkpoint:
+            print_rank_0("No best model checkpoint found, skipping adapter reload")
+            return
+        try:
-        adapter_name = self.model.active_adapter()
-        self.model.delete_adapter(adapter_name)
-        self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+            adapter_name = self.model.active_adapter()
+            self.model.delete_adapter(adapter_name)
+            self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+            print_rank_0(f"Successfully loaded best adapter from {self.state.best_model_checkpoint}")
+        except Exception as e:
+            print_rank_0(f"Failed to load best adapter: {e}")
+            raise
     else:
         super()._load_best_model(*args, **kwargs)

Signed-off-by: Suguna Velury <[email protected]>
@sugunav14 force-pushed the svelury/qlora-ddp-export branch from 7e614ad to f5f91ab on October 6, 2025 06:02

@coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (3)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)

280-290: Add error handling and validate checkpoint existence.

A past review comment noted that this custom LoRA loading logic lacks safety checks:

  1. No validation that self.state.best_model_checkpoint exists before attempting to load
  2. No error handling if load_adapter fails
  3. The condition doesn't check the compress flag, though the TODO mentions this should execute "if compress is enabled and fsdp2 is not enabled"

Apply this diff to add defensive checks:

 def _load_best_model(self, *args, **kwargs):
     """Load the best model for final evaluation."""
     is_lora = getattr(self.args, "lora", None)
-    if is_lora and not self.is_fsdp_enabled:
+    is_compressed = getattr(self.quant_args, "compress", False)
+    if is_lora and not self.is_fsdp_enabled and is_compressed:
         # Custom logic for loading best model with LoRA
         # TODO: Remove once we migrate to using get_peft_model()
+        if not self.state.best_model_checkpoint:
+            print_rank_0("No best model checkpoint found, skipping adapter reload")
+            return
+        try:
-        adapter_name = self.model.active_adapter()
-        self.model.delete_adapter(adapter_name)
-        self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+            adapter_name = self.model.active_adapter()
+            self.model.delete_adapter(adapter_name)
+            self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+            print_rank_0(f"Successfully loaded best adapter from {self.state.best_model_checkpoint}")
+        except Exception as e:
+            print_rank_0(f"Failed to load best adapter: {e}")
+            raise
     else:
         super()._load_best_model(*args, **kwargs)
examples/llm_qat/export.py (1)

53-62: Fix the modelopt state restore order to match the trainer pattern.

A past review comment noted that the restore order differs from the trainer's _restore_modelopt_state_with_weights method. The current implementation pops modelopt_state_weights AFTER calling restore_from_modelopt_state, so restore_from_modelopt_state receives a state dict that still contains the quantizer weights it is not meant to handle.

Trainer pattern (transformers_trainer.py lines 184-190):

  1. Load state dict
  2. Pop modelopt_state_weights
  3. Call restore_from_modelopt_state
  4. Set quantizer state dict

Current pattern:

  1. Load state dict (line 54)
  2. Call restore_from_modelopt_state (line 55)
  3. Pop modelopt_state_weights (line 59)
  4. Set quantizer state dict (line 61)

Apply this diff to align with the trainer's proven pattern:

     # Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
     if hasattr(model, "peft_config"):
         modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
+        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
         restore_from_modelopt_state(model, modelopt_state)
         print_rank_0("Restored modelopt state")
 
         # Restore modelopt quantizer state dict
-        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
         if modelopt_weights is not None:
             set_quantizer_state_dict(model, modelopt_weights)
             print_rank_0("Restored modelopt quantizer state dict")

Based on learnings.

modelopt/torch/export/quant_utils.py (1)

910-914: Refine LoRA key matching to avoid false positives.

The substring check "lora" in key can match unintended keys that merely contain "lora" inside a longer word (e.g., "flora", "colorama"). A past review comment raised this concern and suggested checking against split key parts.

Apply this diff to use more precise matching:

     # remove LoRA adapters from state dict
     if is_modelopt_qlora:
-        for key in post_state_dict:
-            if "lora" in key and key not in keys_to_delete:
-                keys_to_delete.append(key)
+        for key in list(post_state_dict.keys()):
+            parts = key.split(".")
+            if (any(part == "lora" or part.startswith("lora_") for part in parts)) and key not in keys_to_delete:
+                keys_to_delete.append(key)
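
A quick illustration, with made-up key names, of why the part-based check is safer than the plain substring test:

keys = [
    "model.layers.0.q_proj.lora_A.default.weight",  # genuine LoRA adapter key
    "model.layers.0.mlp.colorama_proj.weight",      # "lora" appears only inside "colorama"
]


def is_lora_key(key: str) -> bool:
    parts = key.split(".")
    return any(part == "lora" or part.startswith("lora_") for part in parts)


for key in keys:
    print(f"{key}: substring={'lora' in key}, part_based={is_lora_key(key)}")
# The adapter key passes both checks; the second key is flagged only by the
# naive substring check, i.e. it is a false positive.
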
🧹 Nitpick comments (1)
examples/llm_qat/export.py (1)

37-50: Add validation for checkpoint path and required files.

The get_lora_model function doesn't validate that the checkpoint path exists or contains the required modelopt_state_train.pth file before attempting to load, which could lead to confusing error messages.

Apply this diff to add defensive checks:

 def get_lora_model(
     ckpt_path: str,
     device="cuda",
 ):
     """
     Loads a QLoRA model that has been trained using modelopt trainer.
     """
+    # Validate checkpoint path
+    ckpt_path_obj = Path(ckpt_path)
+    if not ckpt_path_obj.exists():
+        raise FileNotFoundError(f"Checkpoint path does not exist: {ckpt_path}")
+    
     # TODO: Add support for merging adapters in BF16 and merging adapters with quantization for deployment
     device_map = "auto"
     if device == "cpu":
         device_map = "cpu"
 
     # Load model
     model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)
 
     # Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
     if hasattr(model, "peft_config"):
+        modelopt_state_path = ckpt_path_obj / "modelopt_state_train.pth"
+        if not modelopt_state_path.exists():
+            raise FileNotFoundError(
+                f"Expected modelopt state file not found: {modelopt_state_path}. "
+                "This file should be created during training."
+            )
-        modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
+        modelopt_state = torch.load(modelopt_state_path, weights_only=False)
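
One small follow-up on the diff above: it relies on Path, so export.py would also need the corresponding import if it is not already there (the file's import block is not shown in this review):

from pathlib import Path
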
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7e614ad and f5f91ab.

📒 Files selected for processing (6)
  • examples/llm_qat/README.md (2 hunks)
  • examples/llm_qat/export.py (1 hunks)
  • modelopt/torch/export/quant_utils.py (4 hunks)
  • modelopt/torch/export/unified_export_hf.py (3 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (4 hunks)
  • tests/examples/llm_qat/test_llm_qat.py (0 hunks)
💤 Files with no reviewable changes (1)
  • tests/examples/llm_qat/test_llm_qat.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/export/unified_export_hf.py
  • examples/llm_qat/README.md
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T20:46:29.252Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:29.252Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.

Applied to files:

  • examples/llm_qat/export.py
🧬 Code graph analysis (2)
examples/llm_qat/export.py (6)
modelopt/torch/export/convert_hf_config.py (1)
  • convert_hf_quant_config_format (21-117)
modelopt/torch/export/unified_export_hf.py (1)
  • _export_hf_checkpoint (340-499)
modelopt/torch/opt/conversion.py (2)
  • restore_from_modelopt_state (510-567)
  • modelopt_state (444-486)
modelopt/torch/quantization/utils.py (1)
  • set_quantizer_state_dict (459-466)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/opt/plugins/huggingface.py (1)
  • enable_huggingface_checkpointing (127-162)
modelopt/torch/export/quant_utils.py (2)
modelopt/torch/export/unified_export_megatron.py (1)
  • state_dict (465-469)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • maxbound (193-199)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (6)
modelopt/torch/export/quant_utils.py (3)

818-823: LGTM: Parameter addition and docstring are well-documented.

The new is_modelopt_qlora parameter is properly documented and has a sensible default value, maintaining backward compatibility.

Also applies to: 830-830


842-859: LGTM: Dynamic key filtering logic correctly handles modelopt QLoRA.

The conditional extension of replacements and skip_keys for modelopt QLoRA is cleanly implemented, and the updated skip condition using all() properly handles the dynamic skip list.


1093-1099: NVFP4 skip logic is correct
block_size can only be zero when get_weight_block_size returns 0 (no quantizer, disabled, or empty blocks), and get_quantization_format yields NVFP4, NVFP4_AWQ, or W4A8_NVFP4_FP8 only for enabled quantizers with non-zero block sizes. No changes needed.
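
As a reading aid, the guard being verified amounts to roughly the following. This is a simplified sketch; the actual control flow, variable names, and config keys inside get_quant_config differ, and the dict returned here is only illustrative.

NVFP4_FORMATS = {"NVFP4", "NVFP4_AWQ", "W4A8_NVFP4_FP8"}


def per_layer_quant_entry(quant_format: str | None, block_size: int) -> dict | None:
    # Skip emitting a per-layer entry when the quantizer is missing/disabled
    # or its block size could not be determined (block_size == 0).
    if quant_format is None:
        return None
    if quant_format in NVFP4_FORMATS and block_size == 0:
        return None
    return {"quant_algo": quant_format, "group_size": block_size}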

modelopt/torch/quantization/plugins/transformers_trainer.py (2)

212-213: LGTM: ModelOpt state saved after calibration.

Positioning the save call after quantization and before compression ensures the calibration state is captured correctly, addressing the concern from past reviews.
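
To make the ordering concrete, here is a minimal sketch of the intended sequence. It assumes mtq.quantize / mtq.compress as the quantization and compression entry points and mto.modelopt_state for the snapshot; the trainer's actual save path and state layout differ.

import torch

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq


def quantize_snapshot_then_compress(model, quant_cfg, calib_loop, state_path):
    # Calibration runs inside quantize(); amax/scale values are set here.
    mtq.quantize(model, quant_cfg, forward_loop=calib_loop)

    # Snapshot the calibrated ModelOpt state BEFORE compression so the export
    # flow can restore quantizer scales later.
    torch.save(mto.modelopt_state(model), state_path)

    # Compression packs/overwrites weights, so it must come after the snapshot.
    mtq.compress(model)
    return model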


149-149: LGTM: Adapter name parameter removed.

Removing the explicit adapter_name parameter simplifies the API and suggests the adapter name is now handled automatically or uses a sensible default.

Also applies to: 354-354

examples/llm_qat/export.py (1)

67-115: LGTM: Export logic correctly handles QLoRA and standard models.

The main function properly detects QLoRA models, creates the appropriate directory structure (a separate base_model directory for QLoRA), calls the export with the correct flag, and handles both model types in the save logic.
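
For orientation, the save layout described above looks roughly like this. A minimal sketch only: the base_model subdirectory and hf_quant_config.json naming come from the PR description and reviewed files, while the helper name and the exact save calls are assumptions rather than the script's code.

import json
import os


def save_qlora_export(model, tokenizer, post_state_dict, hf_quant_config, export_dir: str):
    base_dir = os.path.join(export_dir, "base_model")
    os.makedirs(base_dir, exist_ok=True)

    # Quantized base model: LoRA-stripped weights plus its quantization config.
    with open(os.path.join(base_dir, "hf_quant_config.json"), "w") as f:
        json.dump(hf_quant_config, f, indent=4)
    model.save_pretrained(base_dir, state_dict=post_state_dict)

    # Tokenizer at the top level next to the adapter weights, which the script
    # writes to export_dir separately (that path is not reproduced here).
    tokenizer.save_pretrained(export_dir)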
