QLoRA DDP export #353
base: main
Conversation
Walkthrough: Adds QLoRA-aware export and training-state handling: updates quant export utilities and HF unified export to handle LoRA-specific keys/quantizers, saves/restores ModelOpt calibration state, provides an export script for QLoRA models, updates docs for vLLM deployment, and re-enables a QLoRA NVFP4 test.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant User
participant ExportScript as examples/llm_qat/export.py
participant Model as QLoRA Model
participant CalibState as modelopt_state_calib.pth
participant Exporter as _export_hf_checkpoint
participant Utils as postprocess_state_dict
User->>ExportScript: run main(args)
ExportScript->>Model: load base model + LoRA adapters
ExportScript->>CalibState: load modelopt_state_calib.pth (if present)
CalibState-->>Model: restore calibration & optional quantizer weights
ExportScript->>Exporter: _export_hf_checkpoint(model, is_modelopt_qlora=True)
Exporter->>Utils: postprocess_state_dict(state_dict, maxbound, quant, is_modelopt_qlora=True)
Utils-->>Exporter: processed_state_dict + hf_quant_config
Exporter-->>ExportScript: return artifacts
ExportScript->>User: write model, adapters, quant config, tokenizer
sequenceDiagram
autonumber
participant Trainer as QATTrainer
participant Model
participant FS as Filesystem
participant Loader as _load_best_model
Trainer->>Model: calibrate quantizers
Trainer->>FS: save modelopt_state_calib.pth
alt load best model (no FSDP, LoRA present)
Trainer->>Loader: _load_best_model(...)
Loader-->>Trainer: custom best-model load (handle adapters)
else
Trainer->>Loader: delegate to superclass
end
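As a minimal sketch of the save/restore round-trip these diagrams depict, assuming the helpers referenced in this PR (`mto.modelopt_state`, `restore_from_modelopt_state`, `get_quantizer_state_dict`/`set_quantizer_state_dict`); the file path and function names below are illustrative, not the PR's exact code:

```python
import torch
import modelopt.torch.opt as mto  # assumption: the modelopt alias used elsewhere in this repo
from modelopt.torch.opt.conversion import restore_from_modelopt_state
from modelopt.torch.quantization.utils import get_quantizer_state_dict, set_quantizer_state_dict


def save_calibration_state(model, path: str) -> None:
    # Capture the ModelOpt calibration state plus quantizer weights before compression.
    state = mto.modelopt_state(model)
    state["modelopt_state_weights"] = get_quantizer_state_dict(model)
    torch.save(state, path)


def restore_calibration_state(model, path: str) -> None:
    # Pop the quantizer weights first so restore_from_modelopt_state only sees the state metadata.
    state = torch.load(path, weights_only=False)
    weights = state.pop("modelopt_state_weights", None)
    restore_from_modelopt_state(model, state)
    if weights is not None:
        set_quantizer_state_dict(model, weights)
```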
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Force-pushed from 035117f to 6254cad.
Codecov Report: ✅ All modified and coverable lines are covered by tests.

@@            Coverage Diff             @@
##             main     #353      +/-   ##
==========================================
- Coverage   73.79%   73.79%   -0.01%
==========================================
  Files         171      171
  Lines       17591    17591
==========================================
- Hits        12982    12981       -1
- Misses       4609     4610       +1
Actionable comments posted: 0
🧹 Nitpick comments (1)
modelopt/torch/export/quant_utils.py (1)

743-745: Verify dtype casting behavior and potential precision loss. The dtype casting is applied unconditionally when a dtype is provided, which could lead to precision loss or unexpected behavior if the original weight dtype is more precise than the target dtype.
Consider adding a warning or validation:

 if dtype:
+    if weight.dtype != dtype:
+        logger.info(f"Converting weight from {weight.dtype} to {dtype}")
     weight = weight.to(dtype)
🔇 Additional comments (14)
modelopt/torch/export/quant_utils.py (4)

730-730: LGTM! Optional dtype parameter added correctly. The addition of the optional `dtype` parameter follows good API design practices with a sensible default of `None`.

833-836: LGTM! Base layer key mappings added appropriately. The new base layer mappings support QLoRA export workflows by handling the transformation from LoRA-specific keys to standard model keys.

847-847: LGTM! Proper exclusion of base_layer keys. The condition correctly filters out base_layer keys from the main processing loop, which is consistent with the new mappings approach.

899-903: LGTM! LoRA adapter cleanup implemented correctly. The LoRA adapter removal logic is properly implemented to clean up adapter-specific parameters from the exported state dict, ensuring a clean base model export.

examples/llm_qat/main.py (1)

276-278: LGTM! QLoRA export integration implemented correctly. The conditional call to `trainer.export_base_model()` is properly guarded by both the LoRA and compression flags, ensuring the base model is only exported when appropriate for QLoRA workflows.

modelopt/torch/export/unified_export_hf.py (5)

88-91: LGTM! Early return for LoRA models prevents unnecessary processing. The early return correctly skips processing for LoRA-finetuned models by detecting the presence of a `base_model` attribute, avoiding potential issues with the requantize/resmooth operations.

329-336: Consistent dtype parameter usage. The non-NVFP4 quantization path correctly passes the dtype parameter to `to_quantized_weight`, maintaining consistency with the NVFP4 path above.

465-470: Enhanced guard conditions for quantized weight export. The additional check for `hasattr(sub_module, "weight_quantizer")` and `sub_module.weight_quantizer.is_enabled` provides better safety by ensuring quantizers exist and are active before attempting export.

531-536: LGTM! Proper base model export for QLoRA models. The logic correctly identifies QLoRA models by checking for the `base_model` attribute and exports the underlying base model instead of the wrapper, which is essential for proper deployment.

317-323: Resolved: internal dtype cast remains; no action needed. quant_utils.py performs `if dtype: weight = weight.to(dtype)` (modelopt/torch/export/quant_utils.py lines 743-744), so removing the pre-cast in unified_export_hf.py does not change quantization behavior.

modelopt/torch/quantization/plugins/transformers_trainer.py (3)

31-31: LGTM! Required import added. The import of `export_hf_checkpoint` is correctly added to support the new export functionality.

279-290: LoRA-specific best model loading logic implemented correctly. The implementation correctly handles the difference between LoRA and non-LoRA models. For LoRA models, it properly removes and re-loads the adapter from the best checkpoint path.
Note: The TODO comment indicates this is temporary until `get_peft_model()` is used, which aligns with the PR description mentioning temporary fixes.

291-296: Simple and effective base model export. The implementation correctly calls `export_hf_checkpoint` with the appropriate output directory structure, and the main process check ensures only one process performs the export.

examples/llm_qat/README.md (1)

357-362: Fix vLLM serve example: point --tokenizer to the base/merged model (keep --lora-modules syntax).
- Location: examples/llm_qat/README.md (lines 357-362).
- Replace the example so the served model is the merged or original base model and --tokenizer points to the base model (or HF name). The current --lora-modules adapter=llama3-fp4-qlora is correct.
- Suggested command (use actual paths/names from your repo): vllm serve --enable-lora --lora-modules adapter= --tokenizer --tokenizer-mode auto --trust-remote-code --port 8000
- Verify tokenizer.pad_token_id == model.config.pad_token_id (or set pad_token = eos_token) to avoid generation/padding issues. Repo check found no base_model directory; confirm exact paths before committing this change.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/export/quant_utils.py (1)

732-747: Guard the new dtype pre-cast to prevent fp8/integer pitfalls. Casting to an arbitrary dtype before quantization can raise "Promotion for Float8 Types is not supported" or produce invalid math if an integer/bool dtype is passed. Restrict to safe float dtypes and reject fp8/ints.
Apply this diff:

 def to_quantized_weight(
     weight: torch.Tensor,
     weights_scaling_factor: torch.Tensor,
     quantization: str,
     weights_scaling_factor2: torch.Tensor | None = None,
     block_size: int | None = None,
     dtype: torch.dtype | None = None,
 ):
 @@
-    if dtype:
-        weight = weight.to(dtype)
+    if dtype is not None:
+        # Only allow >=16-bit float dtypes here; fp8 and non-floats break downstream ops.
+        allowed_dtypes = {torch.float32, torch.float16, torch.bfloat16}
+        disallowed = {getattr(torch, "float8_e4m3fn", None)}
+        if dtype in disallowed or dtype not in allowed_dtypes:
+            raise ValueError(f"Unsupported pre-quant cast dtype: {dtype}. Use float32/float16/bfloat16.")
+        weight = weight.to(dtype)

🧹 Nitpick comments (1)
modelopt/torch/export/quant_utils.py (1)

611-612: Replace the commented-out debug print with logger.debug or remove it. Keep logs consistent and avoid dead code.
Apply this diff:

-    # print(f"DEBUG LOG: Processing layer {k} with quantization {v}, block size {block_size_value}")
+    logger.debug("Processing layer %s with quantization=%s, block_size=%s", k, v, block_size_value)
🔇 Additional comments (1)
modelopt/torch/export/quant_utils.py (1)
835-853: Base-layer key remap: avoid silent drops and verify collisions. Unmapped base_layer.* keys are dropped by design. If other base attributes exist (e.g., bias), they'll vanish silently, and remaps may overwrite existing top-level keys.
Please confirm:
- All required base_layer.* fields are covered by replacements.
- Remapped targets (e.g., “weight”, “input_scale”, “weight_scale”) won’t collide with already-present keys.
If needed, I can add a warning when a base_layer key is encountered but not remapped. Want a patch?
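If such a warning were added, a hypothetical helper might look like the sketch below; the suffix mapping mirrors the replacements discussed in this thread, but the function name and structure are illustrative rather than the PR's actual postprocess_state_dict logic:

```python
import warnings

# Suffix remapping discussed above; anything not listed here is treated as unmapped.
BASE_LAYER_REPLACEMENTS = {
    "base_layer.weight": "weight",
    "base_layer.input_scale": "input_scale",
    "base_layer.weight_scale": "weight_scale",
}


def remap_base_layer_keys(state_dict: dict) -> dict:
    remapped = {}
    for key, value in state_dict.items():
        if "base_layer." not in key:
            remapped[key] = value
            continue
        for old_suffix, new_suffix in BASE_LAYER_REPLACEMENTS.items():
            if key.endswith(old_suffix):
                new_key = key[: -len(old_suffix)] + new_suffix
                if new_key in remapped:
                    warnings.warn(f"Remapped key collides with an existing key: {new_key}")
                remapped[new_key] = value
                break
        else:
            # base_layer.* key with no mapping: warn instead of dropping it silently.
            warnings.warn(f"Unmapped base_layer key dropped: {key}")
    return remapped
```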
Overall implementation looks good. However, we don't have to combine the export for QLoRA with transformers_trainer; we should do the export via hf_ptq.py.

Okay, I will try refactoring the PR to do that! Thank you!
@sugunav14 sounds good, for example here is how we support regular qat deployment - #353 (comment) I am thinking we should have something like:
For deployment of QAT and QLoRA checkpoint, we still need to specify
For QAT/QLoRA, can we support the following usage:
Could you test with Phi4-multimodal-instruct export (FP8 and NVFP4) and make sure the safetensors are the same before and after the change?
To quantize Phi4-multimodal-instruct, you need to:
- Download https://huggingface.co/microsoft/Phi-4-multimodal-instruct
- modify https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py#L2101 and enforce using InputMode.LANGUAGE.
- use transformers 4.48
- Run FP8 and NVFP4 PTQ
- Compare the generated safetensors with https://huggingface.co/nvidia/Phi-4-multimodal-instruct-FP4 and https://huggingface.co/nvidia/Phi-4-multimodal-instruct-FP8. Make sure the tensor keys are the same.
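A minimal sketch of the key-comparison step above, assuming both checkpoints are available locally; the paths are hypothetical, and for sharded checkpoints the keys of all shards would need to be unioned first:

```python
from safetensors import safe_open


def tensor_keys(path: str) -> set[str]:
    # Only the safetensors header is read, so this is cheap even for large files.
    with safe_open(path, framework="pt") as f:
        return set(f.keys())


before = tensor_keys("phi4mm-fp8-reference/model.safetensors")   # hypothetical path
after = tensor_keys("phi4mm-fp8-this-branch/model.safetensors")  # hypothetical path
print("keys missing after the change:", sorted(before - after))
print("keys added after the change:", sorted(after - before))
```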
modelopt/torch/export/quant_utils.py
Outdated

        # Get the corresponding AWQ block size
        block_size_value = layer_config_dict.get(awq_key, 0)

        # print(f"DEBUG LOG: Processing layer {k} with quantization {v}, block size {block_size_value}")
remove
modelopt/torch/export/quant_utils.py
Outdated

    if isinstance(weight, QTensorWrapper):
        return weight.data

    if dtype:
do we have a case where we need to cast the weights?
Yea, seems like in the current unified export logic we do perform a cast before quantizing the weights
modelopt/torch/export/quant_utils.py
Outdated

            keys_to_delete.append(key)

    # remove LoRA adapters from state dict
    for key, value in post_state_dict.items():
what if the original model has lora adapters like phi4-multimodal?
I did not think of that case! Thanks for the catch!
modelopt/torch/export/quant_utils.py
Outdated

        layer_config_dict[name + ".quantization"] = quantization_format
        layer_config_dict[name + ".awq_block_size"] = block_size
        # Handles case if default weight quantizer is not enabled or is None
        if block_size != 0:
will this impact per tensor quant like fp8?
It will, just updated the condition.
Actionable comments posted: 2
🧹 Nitpick comments (9)
examples/llm_qat/export.py (7)
32-56
: Add error handling for missing modelopt state file.Line 47 loads
modelopt_state_calibration.pth
without checking if it exists. If the file is missing (e.g., model was not trained with QLoRA or was trained with an older version),torch.load
will raise aFileNotFoundError
with an unclear error message.Consider adding explicit validation:
+from pathlib import Path + def get_lora_model( ckpt_path: str, device="cuda", ): """ Loads a QLoRA model that has been trained using modelopt trainer. """ + # Validate modelopt state file exists + modelopt_state_path = Path(ckpt_path) / "modelopt_state_calibration.pth" + if not modelopt_state_path.exists(): + raise FileNotFoundError( + f"modelopt_state_calibration.pth not found in {ckpt_path}. " + "Ensure the model was trained with QLoRA using the modelopt trainer." + ) + device_map = "auto" if device == "cpu": device_map = "cpu" # Load model with adapters model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map) # Restore modelopt state - modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calibration.pth", weights_only=False) + modelopt_state = torch.load(modelopt_state_path, weights_only=False) restore_from_modelopt_state(model, modelopt_state)
53-53
: Consider using print_rank_0 for distributed consistency.The print statement on line 53 will execute on all ranks in a distributed setting, potentially causing log clutter. If this script may be used in a distributed context, consider using
print_rank_0
frommodelopt.torch.utils.logging
(already used in transformers_trainer.py).
87-87
: Remove or conditionalize debug print of config data.Line 87 prints the entire config_data dictionary, which can be very verbose. This appears to be debug code left in the script.
Consider removing or making it conditional:
- print(config_data) + # Optionally log config for debugging + # print(config_data)
98-102
: Improve error message clarity.The error message "Cannot export model to the model_config" is unclear. Consider clarifying what failed and providing actionable guidance.
warnings.warn( - "Cannot export model to the model_config. The modelopt-optimized model state_dict" - " can be saved with torch.save for further inspection." + f"Failed to export model to {export_dir}. The modelopt-optimized model state_dict " + "can be saved with torch.save for further inspection." )
29-29
: Remove unused RAND_SEED constant.The
RAND_SEED
constant is defined but never used in the script.-RAND_SEED = 1234 -
32-35
: Add type hints to function signatures.The function lacks type hints, which would improve code clarity and enable static type checking. Consider adding return type annotation.
-def get_lora_model( - ckpt_path: str, - device="cuda", -): +def get_lora_model( + ckpt_path: str, + device: str = "cuda", +) -> torch.nn.Module:
59-59
: Add type hints to main function.Consider adding type hints for better code clarity.
-def main(args): +def main(args: argparse.Namespace) -> None:modelopt/torch/export/quant_utils.py (2)
903-907
: Improve LoRA key detection to avoid false positives.The substring check
"lora" in key
on line 906 may inadvertently match keys that contain "lora" as part of a larger word (e.g., "flora", "explorer") or match model-native LoRA adapters that should be preserved.Based on learnings
Apply this diff to use more precise path-segment matching:
# remove LoRA adapters from state dict if is_modelopt_qlora: - for key in post_state_dict: - if "lora" in key and key not in keys_to_delete: + for key in list(post_state_dict.keys()): + parts = key.split(".") + # Check if "lora" appears as a complete path segment or as a prefix (e.g., lora_A, lora_B) + if (("lora" in parts or any(p.startswith("lora_") for p in parts)) + and key not in keys_to_delete): keys_to_delete.append(key)Note: Also changed to iterate over
list(post_state_dict.keys())
to avoid issues with dictionary modification during iteration.
1086-1096
: Improvement: NVFP4-specific block_size handling partially addresses past concerns.The new logic correctly skips layer config entries for NVFP4 formats when
block_size == 0
, which indicates the weight_quantizer is not enabled. This is an improvement over the previous approach that skipped ALL formats withblock_size == 0
.However,
awq_block_size
is still written for all formats on line 1096, even those that don't use block quantization (e.g.,QUANTIZATION_INT8_SQ
,QUANTIZATION_FP8
,QUANTIZATION_FP8_PC_PT
).Consider only writing
awq_block_size
for formats that actually use it:# Construct per layer config dictionary layer_config_dict[name + ".quantization"] = quantization_format - layer_config_dict[name + ".awq_block_size"] = block_size + # Only write block_size for block-quantized formats + if block_size > 0: + layer_config_dict[name + ".awq_block_size"] = block_sizeThis avoids polluting the config with unnecessary zero values for per-tensor formats, though the
process_layer_quant_config
function (line 601) already filters these out withif "awq_block_size" in k: continue
.
🔇 Additional comments (5)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)
212-219
: LGTM! Pre-compression state capture enables correct export.Saving the modelopt state and quantizer weights before compression is essential for the export workflow. The export.py script (line 47) loads this exact file to restore the quantizer configuration needed for proper model export. The timing (post-calibration, pre-compression) ensures the quantization metadata is preserved for deployment.
287-297
: Ignore attribute name inconsistency suggestion. args.lora is the boolean toggle and args.lora_config holds the adapter config; both are used correctly.Likely an incorrect or invalid review comment.
149-149
: Verify default adapter naming inadd_adapter
calls.
- Confirm that
self.model.add_adapter(self.args.lora_config)
(lines 149 & 361) assigns a defaultadapter_name
thatself.model.active_adapter()
(line 293) returns, sodelete_adapter
andload_adapter
operate correctly.- If not, pass an explicit
adapter_name
toadd_adapter
to guarantee consistent lifecycle management.examples/llm_qat/export.py (1)
71-76
: Clarify which quantization config format is saved to hf_quant_config.json.Line 73-74 saves the original
hf_quant_config
(modelopt format), while line 76 converts it to llm-compressor format and embeds it in config.json (line 89). This means two different formats are persisted:
base_model/hf_quant_config.json
: modelopt formatbase_model/config.json
(quantization_config field): llm-compressor formatIf both formats are needed for different consumers, consider adding a comment explaining why both are saved. Otherwise, consider saving only the converted format.
modelopt/torch/export/quant_utils.py (1)
835-846
: Potential logic issue: skip_keys may prevent replacements from being applied.On line 846,
"base_layer"
is appended toskip_keys
. Then on line 852, the code skips any key whereall(skip_key not in key for skip_key in skip_keys)
is true. This means keys containing"base_layer"
will be skipped entirely and never reach the replacement logic (lines 857-882).However, the replacements dictionary (lines 839-845) includes patterns like
"base_layer.weight"
which should be transformed to"weight"
. These replacements won't be applied because the keys are filtered out first.Consider refactoring to apply replacements first, then skip keys that should be removed:
- skip_keys = ["output_quantizer", "_amax", "_bias_value", "input_quantizer._pre_quant_scale"] + # Keys to skip entirely (not related to quantizers or base transformations) + skip_patterns = ["output_quantizer"] # For modelopt-trained LoRA models, we need to remove the base_layer prefix from the keys for deployment if is_modelopt_qlora: replacements.update( { "base_layer.weight": "weight", "base_layer.input_scale": "input_scale", "base_layer.weight_scale": "weight_scale", } ) - skip_keys.append("base_layer") post_state_dict = {} for key, value in state_dict.items(): - # Skip keys not related to quantizers - if all(skip_key not in key for skip_key in skip_keys): - post_state_dict[key] = value - continue + # Skip keys that should be entirely filtered out + if any(skip_pattern in key for skip_pattern in skip_patterns): + continue - # Apply replacements if the key matches any suffix in the replacements dict + # Try to apply replacements first + replaced = False for old_suffix, new_suffix in replacements.items(): if key.endswith(old_suffix): prefix = key[: -len(old_suffix)] if "_amax" in key: # ... existing _amax handling ... post_state_dict[prefix + new_suffix] = value + replaced = True break + + # If no replacement was applied and key doesn't contain quantizer-specific suffixes, keep it + if not replaced: + quantizer_suffixes = ["_amax", "_bias_value", "input_quantizer._pre_quant_scale"] + if not any(key.endswith(suffix) for suffix in quantizer_suffixes): + post_state_dict[key] = valueLikely an incorrect or invalid review comment.
Actionable comments posted: 1
🧹 Nitpick comments (4)
examples/llm_qat/export.py (4)
29-29
: Remove unused constant.
RAND_SEED
is defined but never used in the script.-RAND_SEED = 1234 -
32-56
: Consider improving robustness and observability.The function lacks error handling for missing checkpoint files and uses print instead of logging. Consider:
- Error handling: Wrap file loading in try-except to provide clear error messages if
modelopt_state_calibration.pth
is missing or corrupted.- Logging: Replace
print("Restoring modelopt weights")
with proper logging (e.g.,logging.info(...)
).- Documentation: Expand the docstring to document parameters, return value, and expected checkpoint structure.
Example improvements:
+import logging + def get_lora_model( ckpt_path: str, device="cuda", ): """ - Loads a QLoRA model that has been trained using modelopt trainer. + Loads a QLoRA model that has been trained using modelopt trainer. + + Args: + ckpt_path: Path to the checkpoint directory containing the model and modelopt state. + device: Device to load the model on ("cuda" or "cpu"). + + Returns: + The loaded model with restored modelopt and quantizer state. """ device_map = "auto" if device == "cpu": device_map = "cpu" # Load model with adapters model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map) # Restore modelopt state + try: - modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calibration.pth", weights_only=False) + modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calibration.pth", weights_only=False) + except FileNotFoundError as e: + raise FileNotFoundError( + f"modelopt_state_calibration.pth not found in {ckpt_path}. " + "Ensure the checkpoint was saved correctly during training." + ) from e restore_from_modelopt_state(model, modelopt_state) # Restore modelopt quantizer state dict modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) if modelopt_weights is not None: - print("Restoring modelopt weights") + logging.info("Restoring modelopt quantizer weights") set_quantizer_state_dict(model, modelopt_weights) return model
59-101
: Use Path objects consistently and improve error messages.The function mixes Path objects with f-string concatenation. For consistency and robustness, use Path objects throughout. Also, the exception handler's warning message is generic and doesn't help users diagnose the issue.
Apply this diff:
def main(args): # Load model model = get_lora_model(args.pyt_ckpt_path, args.device) tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path) # Export HF checkpoint export_dir = Path(args.export_path) export_dir.mkdir(parents=True, exist_ok=True) base_model_dir = export_dir / "base_model" base_model_dir.mkdir(parents=True, exist_ok=True) try: post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=True) - with open(f"{export_dir}/base_model/hf_quant_config.json", "w") as file: + with open(base_model_dir / "hf_quant_config.json", "w") as file: json.dump(hf_quant_config, file, indent=4) hf_quant_config = convert_hf_quant_config_format(hf_quant_config) # Save base model - model.base_model.save_pretrained(f"{export_dir}/base_model", state_dict=post_state_dict) + model.base_model.save_pretrained(base_model_dir, state_dict=post_state_dict) # Save adapters model.save_pretrained(export_dir) - config_path = f"{export_dir}/base_model/config.json" + config_path = base_model_dir / "config.json" # In the case of LoRA model.save_pretrained does not save the correct config.json config_data = model.config.to_dict() config_data["quantization_config"] = hf_quant_config with open(config_path, "w") as file: json.dump(config_data, file, indent=4) # Save tokenizer tokenizer.save_pretrained(export_dir) except Exception as e: warnings.warn( - "Cannot export model to the model_config. The modelopt-optimized model state_dict" - " can be saved with torch.save for further inspection." + f"Export failed: {e}. The modelopt-optimized model state_dict " + "can be saved with torch.save for further inspection." ) raise e
85-86
: Clarify the comment about config.json handling.The comment is grammatically incomplete. Consider rephrasing for clarity.
- # In the case of LoRA model.save_pretrained does not save the correct config.json + # For LoRA models, save_pretrained does not include quantization_config in config.json, + # so we manually reconstruct and write it here.
🔇 Additional comments (2)
examples/llm_qat/export.py (2)
104-118
: LGTM!The argument parser is well-structured with appropriate defaults and a required checkpoint path.
1-118
: Acknowledge planned refactor per PR discussion.Based on the PR objectives, this script is temporary. The author has agreed to refactor the export logic to integrate with
hf_ptq.py
per reviewer feedback. This standalone script serves the immediate QLoRA export use case but should be consolidated with the existing export flow in a follow-up.Based on PR objectives summary.
Actionable comments posted: 1
🧹 Nitpick comments (2)
examples/llm_qat/README.md (1)
368-370
: Consider clarifying the adapter name convention.The vLLM command uses
--lora-modules adapter=llama3-fp4-qlora-hf
where "adapter" is a generic name. It would be helpful to clarify whether this name is:
- A fixed convention expected by vLLM
- User-configurable and should match a specific naming pattern
- Related to the export output structure
Adding a brief note about the adapter naming would improve usability, especially given the PR discussion about simplifying the CLI for QLoRA checkpoints.
modelopt/torch/export/quant_utils.py (1)
904-908
: Substring matching may catch unintended keys.The LoRA adapter removal logic uses substring matching (
"lora" in key
), which could inadvertently match keys like "flora", "exploration", or "coloration". While theis_modelopt_qlora
gate reduces this risk, consider more precise matching.Apply this diff for more precise LoRA key detection:
# remove LoRA adapters from state dict if is_modelopt_qlora: for key in post_state_dict: - if "lora" in key and key not in keys_to_delete: + # Match LoRA keys more precisely: check for .lora. or lora_ patterns + parts = key.split(".") + if (any(p.startswith("lora_") or p == "lora" for p in parts)) and key not in keys_to_delete: keys_to_delete.append(key)This ensures matching only keys with "lora" as a complete path component or prefix (e.g., "model.lora_A.weight", "adapter.lora.bias"), not arbitrary substrings.
🔇 Additional comments (6)
examples/llm_qat/README.md (1)
357-370
: Documentation provides clear QLoRA deployment flow.The new export and deployment instructions successfully replace the previous experimental note with actionable guidance. The command examples are well-structured and the reference to vLLM documentation is helpful.
modelopt/torch/export/quant_utils.py (5)
812-827
: LGTM: Parameter documented.The new
is_modelopt_qlora
parameter has been properly documented in the docstring at line 824. The documentation is clear and consistent with the existing style.
836-836
: LGTM: Centralized skip logic.The
skip_keys
list provides a cleaner, more maintainable approach to filtering non-quantizer keys. This addresses the concern from previous reviews about structured key handling.
838-847
: LGTM: LoRA handling properly gated.The LoRA-specific key transformations are now correctly gated behind the
is_modelopt_qlora
flag. This addresses the previous concern about unconditional LoRA adapter removal and ensures the logic only applies to ModelOpt QLoRA exports.
1087-1093
: LGTM: NVFP4 block_size filtering fixed.The guard correctly skips only NVFP4-related formats when
block_size == 0
, indicating the weight quantizer is not enabled. This fixes the previous issue where non-block formats (INT8_SQ, FP8, FP8_PC_PT) were incorrectly dropped. Non-block formats now proceed to lines 1095-1097 regardless of block_size.
851-855
: postprocess_state_dict filtering logic is correct
Non-quantizer keys are retained, and all quantizer/base_layer keys are either transformed per the replacements mapping or removed as intended.
        block_size = get_weight_block_size(module)

        # In the case of NVFP4, block_size 0 indicates weight_quantizer is not enabled
        if block_size == 0 and quantization_format in [
do we have a better flag instead of checking the block_size? E.g. weight_quantizer enabled vs disabled?
let me check and update the PR!
+1
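A sketch of what the flag-based check could look like, assuming modules expose a TensorQuantizer at `weight_quantizer` with an `is_enabled` property as referenced elsewhere in this PR; this is illustrative, not the final condition:

```python
def weight_quantizer_enabled(module) -> bool:
    # Check the quantizer's own enabled state instead of inferring it from block_size == 0.
    wq = getattr(module, "weight_quantizer", None)
    return wq is not None and getattr(wq, "is_enabled", False)


# Hypothetical usage inside the per-layer config loop:
# if not weight_quantizer_enabled(module):
#     continue  # skip the layer config entry for this module
```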
            # Save modelopt state before compression. This is used to later export the model for deployment.
            modelopt_state = mto.modelopt_state(self.model)
            modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(self.model)
            torch.save(modelopt_state, f"{self.args.output_dir}/modelopt_state_calib.pth")

            print_rank_0(
                f"Saved modelopt state before compression to {f'{self.args.output_dir}/modelopt_state_calib.pth'}"
            )
Should we move this below and save modelopt state only before compression?
            # TODO: Remove once we migrate to using get_peft_model()
            adapter_name = self.model.active_adapter()
            self.model.delete_adapter(adapter_name)
            self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
qq, does this only load the lora adapter? Maybe we need to add a check to make sure the base model is expected to be frozen/compressed.
Yes, this loads only the best lora adapter. The reason I introduced this logic is our current workflow seems slightly incompatible with HF trainer as we are using .add_adapter() instead of get_peft_model() due to which HF trainer doesn't detect it as a peft model. This is causing some errors in the final load_best_checkpoint() call so I added this temporary fix until we migrate to using get_peft_model().
Currently I execute this logic only if compress is enabled and fsdp2 is not enabled (indicating DDP QLoRA). Do you recommend any other checks?
examples/llm_qat/export.py
Outdated

        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=True)

        with open(f"{base_model_dir}/hf_quant_config.json", "w") as file:
            json.dump(hf_quant_config, file, indent=4)

        hf_quant_config = convert_hf_quant_config_format(hf_quant_config)

        # Save base model
        model.base_model.save_pretrained(f"{base_model_dir}", state_dict=post_state_dict)
        # Save adapters
        model.save_pretrained(export_dir)

        config_path = f"{base_model_dir}/config.json"

        config_data = model.config.to_dict()

        config_data["quantization_config"] = hf_quant_config

        with open(config_path, "w") as file:
            json.dump(config_data, file, indent=4)

        # Save tokenizer
        tokenizer.save_pretrained(export_dir)
Does it make more sense to add this to utils as a function (or to add the QLoRA-specific logic as part of export_hf_checkpoint so it can be used by other scripts easily) that can be used to export at the end of training? The main file could provide an optional export flag instead of creating a new export.py.
I could move the saving logic to export_hf_checkpoint() and simplify the script. The reason I have a separate script is so that we don't tie the export to the main training script. This could be useful in the case where the user has a couple of saved checkpoints and later wants to try exporting with different deployment options (merged adapters in bf16, merged adapters with quantization or quantized base model with bf16 adapters).
I also wanted to avoid going through hf_ptq.py because it would add more complexity to our existing hf_ptq.py script (user having to specify an additional lora flag, potentially having to add more flags to specify available deployment options for QLoRA, either user having to keep track and specify qformat of base model/ additional logic to save that information and infer it)
@kinjalpatel27 do you see another example can benefit from this export script as well?
I agree it makes sense to avoid using hf_ptq.py for this export.
I think simplifying the script and having it as a function might be better if a user wants to create their own script to train and later export for deployment. I don't have another example script in mind atm.
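As a rough illustration of the "export as a function" idea in this thread, the hunk above could be packaged roughly as below; the helper name is hypothetical and the imports assume the module paths referenced elsewhere in this PR:

```python
import json
from pathlib import Path

from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint


def export_qlora_checkpoint(model, tokenizer, export_dir: str) -> None:
    export_path = Path(export_dir)
    base_model_dir = export_path / "base_model"
    base_model_dir.mkdir(parents=True, exist_ok=True)

    post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=True)
    (base_model_dir / "hf_quant_config.json").write_text(json.dumps(hf_quant_config, indent=4))

    # Save the quantized base model and the LoRA adapters separately.
    model.base_model.save_pretrained(base_model_dir, state_dict=post_state_dict)
    model.save_pretrained(export_path)

    # Patch config.json with the converted quantization config.
    config_data = model.config.to_dict()
    config_data["quantization_config"] = convert_hf_quant_config_format(hf_quant_config)
    (base_model_dir / "config.json").write_text(json.dumps(config_data, indent=4))

    tokenizer.save_pretrained(export_path)
```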
            # Save modelopt state before compression. This is used to later export the model for deployment.
            modelopt_state = mto.modelopt_state(self.model)
            modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(self.model)
            torch.save(modelopt_state, f"{self.args.output_dir}/modelopt_state_calib.pth")

            print_rank_0(
                f"Saved modelopt state before compression to {f'{self.args.output_dir}/modelopt_state_calib.pth'}"
            )
Can we use self._save_modelopt_state_with_weights() instead of this?
This save here does not correctly handle distributed training. We should not save from all ranks.
Thanks for the catch! Will update.
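A minimal sketch of the rank-0-only save being requested; whether `self._save_modelopt_state_with_weights()` already behaves exactly like this is an assumption from this thread:

```python
import torch
import torch.distributed as dist


def save_on_main_process_only(state: dict, path: str) -> None:
    # Only rank 0 writes the file; other ranks wait so the file exists before they read it.
    if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
        torch.save(state, path)
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
```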
            self.model.add_adapter(self.args.lora_config)
            print_rank_0("Lora adapter added.")

        if hasattr(self.model, "peft_config") and self.quant_cfg is not None:
Why do we need `disable_lora_quantizers_in_config`? This does not seem warranted to me.
Actionable comments posted: 2
🧹 Nitpick comments (2)
examples/llm_qat/export.py (2)
71-71
: Make QLoRA detection more explicit.The condition
is_qlora = hasattr(model, "peft_config")
assumes any model with PEFT config is a QLoRA model. This may incorrectly identify regular LoRA (without quantization) as QLoRA. Consider checking for both PEFT config and quantization state:- is_qlora = hasattr(model, "peft_config") + from modelopt.torch.quantization.utils import is_quantized + is_qlora = hasattr(model, "peft_config") and is_quantized(model)
109-114
: Make exception handling more specific and informative.The broad
except Exception
catches all exceptions with a generic warning message. This can hide specific failures (I/O errors, permission issues, serialization errors) and makes debugging harder.Consider handling specific exception types:
- except Exception as e: - warnings.warn( - "Cannot export model to the model_config. The modelopt-optimized model state_dict" - " can be saved with torch.save for further inspection." - ) - raise e + except (OSError, IOError) as e: + warnings.warn(f"Failed to write export files to {base_model_dir}: {e}") + raise + except Exception as e: + warnings.warn( + f"Export failed: {e}. The modelopt-optimized model state_dict " + "can be saved with torch.save for further inspection." + ) + raise
🔇 Additional comments (2)
modelopt/torch/quantization/plugins/transformers_trainer.py (2)
212-213: Confirm state save timing aligns with export requirements. The ModelOpt state is now saved immediately after quantization and before compression (line 217). A previous reviewer (meenchen) questioned whether saving should occur only before compression. Please confirm this placement correctly captures the calibrated quantizer state needed for the export workflow, especially considering that compression modifies weights.

149-149: Adapter_name removal is safe: searches found no hardcoded "adapter" strings or API calls expecting a literal adapter name.
    if hasattr(model, "peft_config"):
        modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
        restore_from_modelopt_state(model, modelopt_state)
        print_rank_0("Restored modelopt state")

        # Restore modelopt quantizer state dict
        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
        if modelopt_weights is not None:
            set_quantizer_state_dict(model, modelopt_weights)
            print_rank_0("Restored modelopt quantizer state dict")
Verify the state restoration order matches the trainer's pattern.
The sequence here differs from the trainer's `_restore_modelopt_state_with_weights` method (transformers_trainer.py lines 184-190):
Trainer pattern:
- Load state dict
- Pop `modelopt_state_weights`
- Call `restore_from_modelopt_state`
- Set quantizer state dict
Current export.py pattern:
- Load state dict
- Call `restore_from_modelopt_state` (line 55)
- Pop `modelopt_state_weights` (line 59)
- Set quantizer state dict
The pop operation at line 59 occurs after `restore_from_modelopt_state`, which may cause the method to receive and process the weights that should be handled separately. Align the order with the trainer's proven pattern.
Apply this diff:
# Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
if hasattr(model, "peft_config"):
modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
+ modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
restore_from_modelopt_state(model, modelopt_state)
print_rank_0("Restored modelopt state")
# Restore modelopt quantizer state dict
- modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
if modelopt_weights is not None:
set_quantizer_state_dict(model, modelopt_weights)
print_rank_0("Restored modelopt quantizer state dict")
    def _load_best_model(self, *args, **kwargs):
        """Load the best model for final evaluation."""
        is_lora = getattr(self.args, "lora", None)
        if is_lora and not self.is_fsdp_enabled:
            # Custom logic for loading best model with LoRA
            # TODO: Remove once we migrate to using get_peft_model()
            adapter_name = self.model.active_adapter()
            self.model.delete_adapter(adapter_name)
            self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
        else:
            super()._load_best_model(*args, **kwargs)
Add error handling and validate checkpoint existence.
The custom LoRA loading logic lacks safety checks:
- No validation that `self.state.best_model_checkpoint` exists before attempting to load
- No error handling if `load_adapter` fails
- The condition checks `is_lora` and `not self.is_fsdp_enabled`, but an earlier comment thread mentioned this should only execute "if compress is enabled and fsdp2 is not enabled." The current logic doesn't check the compress flag.
Add defensive checks:
def _load_best_model(self, *args, **kwargs):
"""Load the best model for final evaluation."""
is_lora = getattr(self.args, "lora", None)
- if is_lora and not self.is_fsdp_enabled:
+ is_compressed = getattr(self.quant_args, "compress", False)
+ if is_lora and not self.is_fsdp_enabled and is_compressed:
# Custom logic for loading best model with LoRA
# TODO: Remove once we migrate to using get_peft_model()
+ if not self.state.best_model_checkpoint:
+ print_rank_0("No best model checkpoint found, skipping adapter reload")
+ return
+ try:
- adapter_name = self.model.active_adapter()
- self.model.delete_adapter(adapter_name)
- self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+ adapter_name = self.model.active_adapter()
+ self.model.delete_adapter(adapter_name)
+ self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+ print_rank_0(f"Successfully loaded best adapter from {self.state.best_model_checkpoint}")
+ except Exception as e:
+ print_rank_0(f"Failed to load best adapter: {e}")
+ raise
else:
super()._load_best_model(*args, **kwargs)
Signed-off-by: Suguna Velury <[email protected]>
Force-pushed from 7e614ad to f5f91ab.
Actionable comments posted: 0
♻️ Duplicate comments (3)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
280-290
: Add error handling and validate checkpoint existence.A past review comment noted that this custom LoRA loading logic lacks safety checks:
- No validation that
self.state.best_model_checkpoint
exists before attempting to load- No error handling if
load_adapter
fails- The condition doesn't check the compress flag, though the TODO mentions this should execute "if compress is enabled and fsdp2 is not enabled"
Apply this diff to add defensive checks:
def _load_best_model(self, *args, **kwargs): """Load the best model for final evaluation.""" is_lora = getattr(self.args, "lora", None) - if is_lora and not self.is_fsdp_enabled: + is_compressed = getattr(self.quant_args, "compress", False) + if is_lora and not self.is_fsdp_enabled and is_compressed: # Custom logic for loading best model with LoRA # TODO: Remove once we migrate to using get_peft_model() + if not self.state.best_model_checkpoint: + print_rank_0("No best model checkpoint found, skipping adapter reload") + return + try: adapter_name = self.model.active_adapter() self.model.delete_adapter(adapter_name) self.model.load_adapter(self.state.best_model_checkpoint, adapter_name) + print_rank_0(f"Successfully loaded best adapter from {self.state.best_model_checkpoint}") + except Exception as e: + print_rank_0(f"Failed to load best adapter: {e}") + raise else: super()._load_best_model(*args, **kwargs)examples/llm_qat/export.py (1)
53-62: Fix the modelopt state restore order to match the trainer pattern.
A past review comment noted that the restore order differs from the trainer's _restore_modelopt_state_with_weights method. The current implementation pops modelopt_state_weights AFTER calling restore_from_modelopt_state, which may cause the method to receive and process weights that should be handled separately.
Trainer pattern (transformers_trainer.py lines 184-190):
- Load state dict
- Pop modelopt_state_weights
- Call restore_from_modelopt_state
- Set quantizer state dict
Current pattern:
- Load state dict (line 54)
- Call restore_from_modelopt_state (line 55)
- Pop modelopt_state_weights (line 59)
- Set quantizer state dict (line 61)
Apply this diff to align with the trainer's proven pattern:

 # Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
 if hasattr(model, "peft_config"):
     modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
+    modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
     restore_from_modelopt_state(model, modelopt_state)
     print_rank_0("Restored modelopt state")

     # Restore modelopt quantizer state dict
-    modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
     if modelopt_weights is not None:
         set_quantizer_state_dict(model, modelopt_weights)
         print_rank_0("Restored modelopt quantizer state dict")

Based on learnings.
modelopt/torch/export/quant_utils.py (1)
910-914: Refine LoRA key matching to avoid false positives.
The substring check "lora" in key can match unintended keys containing "lora" as part of a longer word (e.g., "flora", "colorama", "explorer_adapter"). A past review comment raised this concern and suggested checking against split key parts.
Apply this diff to use more precise matching:

 # remove LoRA adapters from state dict
 if is_modelopt_qlora:
-    for key in post_state_dict:
-        if "lora" in key and key not in keys_to_delete:
-            keys_to_delete.append(key)
+    for key in list(post_state_dict.keys()):
+        parts = key.split(".")
+        if (any(part == "lora" or part.startswith("lora_") for part in parts)) and key not in keys_to_delete:
+            keys_to_delete.append(key)
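To illustrate the distinction, here is a small standalone sketch (the state-dict keys are hypothetical, not taken from this PR) showing how part-based matching drops LoRA adapter tensors while keeping keys that merely contain "lora" as a substring:

# Hypothetical state-dict keys for illustration only.
keys = [
    "model.layers.0.self_attn.q_proj.lora_A.weight",      # LoRA adapter tensor
    "model.layers.0.self_attn.q_proj.base_layer.weight",  # base weight
    "model.colorama_adapter.weight",                       # "lora" appears only as a substring
]

def is_lora_key(key: str) -> bool:
    """Match whole dotted name parts, so names like 'colorama' are not mistaken for LoRA tensors."""
    parts = key.split(".")
    return any(part == "lora" or part.startswith("lora_") for part in parts)

for key in keys:
    naive = "lora" in key
    precise = is_lora_key(key)
    print(f"{key} -> substring check: {'drop' if naive else 'keep'}, part check: {'drop' if precise else 'keep'}")

Running this prints a false positive for the third key under the substring check, while the part-based check keeps it.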
🧹 Nitpick comments (1)
examples/llm_qat/export.py (1)
37-50: Add validation for checkpoint path and required files.
The get_lora_model function doesn't validate that the checkpoint path exists or contains the required modelopt_state_train.pth file before attempting to load, which could lead to confusing error messages.
Apply this diff to add defensive checks:

 def get_lora_model(
     ckpt_path: str,
     device="cuda",
 ):
     """
     Loads a QLoRA model that has been trained using modelopt trainer.
     """
+    # Validate checkpoint path
+    ckpt_path_obj = Path(ckpt_path)
+    if not ckpt_path_obj.exists():
+        raise FileNotFoundError(f"Checkpoint path does not exist: {ckpt_path}")
+
     # TODO: Add support for merging adapters in BF16 and merging adapters with quantization for deployment
     device_map = "auto"
     if device == "cpu":
         device_map = "cpu"

     # Load model
     model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)

     # Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
     if hasattr(model, "peft_config"):
+        modelopt_state_path = ckpt_path_obj / "modelopt_state_train.pth"
+        if not modelopt_state_path.exists():
+            raise FileNotFoundError(
+                f"Expected modelopt state file not found: {modelopt_state_path}. "
+                "This file should be created during training."
+            )
-        modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
+        modelopt_state = torch.load(modelopt_state_path, weights_only=False)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
examples/llm_qat/README.md (2 hunks)
examples/llm_qat/export.py (1 hunks)
modelopt/torch/export/quant_utils.py (4 hunks)
modelopt/torch/export/unified_export_hf.py (3 hunks)
modelopt/torch/quantization/plugins/transformers_trainer.py (4 hunks)
tests/examples/llm_qat/test_llm_qat.py (0 hunks)
💤 Files with no reviewable changes (1)
- tests/examples/llm_qat/test_llm_qat.py
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/torch/export/unified_export_hf.py
- examples/llm_qat/README.md
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T20:46:29.252Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:29.252Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.
Applied to files:
examples/llm_qat/export.py
🧬 Code graph analysis (2)
examples/llm_qat/export.py (6)
modelopt/torch/export/convert_hf_config.py (1)
- convert_hf_quant_config_format (21-117)
modelopt/torch/export/unified_export_hf.py (1)
- _export_hf_checkpoint (340-499)
modelopt/torch/opt/conversion.py (2)
- restore_from_modelopt_state (510-567)
- modelopt_state (444-486)
modelopt/torch/quantization/utils.py (1)
- set_quantizer_state_dict (459-466)
modelopt/torch/utils/logging.py (1)
- print_rank_0 (92-95)
modelopt/torch/opt/plugins/huggingface.py (1)
- enable_huggingface_checkpointing (127-162)
modelopt/torch/export/quant_utils.py (2)
modelopt/torch/export/unified_export_megatron.py (1)
- state_dict (465-469)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
- maxbound (193-199)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (6)
modelopt/torch/export/quant_utils.py (3)
818-823: LGTM: Parameter addition and docstring are well-documented.
The new is_modelopt_qlora parameter is properly documented and has a sensible default value, maintaining backward compatibility. Also applies to: 830-830
842-859: LGTM: Dynamic key filtering logic correctly handles modelopt QLora.
The conditional extension of replacements and skip_keys for modelopt QLora is cleanly implemented, and the updated skip condition using all() properly handles the dynamic skip list.
1093-1099: NVFP4 skip logic is correct.
block_size can only be zero when get_weight_block_size returns 0 (no quantizer, disabled, or empty blocks), and get_quantization_format yields NVFP4, NVFP4_AWQ, or W4A8_NVFP4_FP8 only for enabled quantizers with non-zero block sizes. No changes needed.
modelopt/torch/quantization/plugins/transformers_trainer.py (2)
212-213: LGTM: ModelOpt state saved after calibration.
Positioning the save call after quantization and before compression ensures the calibration state is captured correctly, addressing the concern from past reviews.
149-149: LGTM: Adapter name parameter removed.
Removing the explicit adapter_name parameter simplifies the API and suggests the adapter name is now handled automatically or uses a sensible default. Also applies to: 354-354
examples/llm_qat/export.py (1)
67-115: LGTM: Export logic correctly handles QLora and standard models.
The main function properly detects QLora models, creates the appropriate directory structure (separate base_model directory for QLora), calls the export with the correct flag, and handles both model types in the save logic.
What does this PR do?
Type of change: New example
Overview: This PR provides an end-to-end example of fine-tuning a model with QLoRA under DDP and exporting the resulting checkpoint for deployment with vLLM.
Usage
Refer to README.md changes
Testing
Trainer
./launch.sh --model meta-llama/Meta-Llama-3-8B --num_epochs 0.01 --lr 1e-3 --do_train True --output_dir test --quant_cfg FP8_DEFAULT_CFG --compress True --lora True
Export
python export.py --pyt_ckpt_path test --export_dir test-fp8
Deployment
vllm serve test-fp8/base_model --enable-lora --lora-modules sql-lora=test-fp8 --port 8090 --tokenizer test-fp8
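Once the server is up, a quick smoke test is to send an OpenAI-compatible completion request against the adapter (a minimal sketch: the port and the sql-lora adapter name simply mirror the serve command above, and the prompt and token limit are placeholders):

import json
import urllib.request

# Assumes the vLLM server started with the command above is listening on localhost:8090
# and that the LoRA adapter is registered under the name "sql-lora".
payload = {
    "model": "sql-lora",
    "prompt": "Write a SQL query that counts the rows in the users table:",
    "max_tokens": 64,
    "temperature": 0.0,
}
request = urllib.request.Request(
    "http://localhost:8090/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    result = json.load(response)
print(result["choices"][0]["text"])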
e2e unit test
Sanity check weights, dtypes of generated checkpoint
Test phi4
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Improvements
Documentation
Tests