-
Notifications
You must be signed in to change notification settings - Fork 169
QLoRA DDP export #353
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
QLoRA DDP export #353
Changes from all commits
2304e4d
c64f6d9
72ddf29
c5c6a57
994a3a0
bf2a0f9
70200f1
3cf89cf
5fcef97
7310346
40202eb
f1a2ff3
3b7ef44
0b19255
f484b5d
32e6330
f9e8bf6
fbc0278
0b89f8a
bc6c835
bb2d6ef
b81b4de
f5f91ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,135 @@ | ||||||||||||||||||||||||||||||||||||||||||||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||
# SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||||||||||||||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||||||||||
# you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||||||||||
# You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||||||||||||||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||||||||||||||||||
# Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||||||||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||||||||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||
# limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
import argparse | ||||||||||||||||||||||||||||||||||||||||||||
import json | ||||||||||||||||||||||||||||||||||||||||||||
import warnings | ||||||||||||||||||||||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||||||||||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
import modelopt.torch.opt as mto | ||||||||||||||||||||||||||||||||||||||||||||
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format | ||||||||||||||||||||||||||||||||||||||||||||
from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint | ||||||||||||||||||||||||||||||||||||||||||||
from modelopt.torch.opt.conversion import restore_from_modelopt_state | ||||||||||||||||||||||||||||||||||||||||||||
from modelopt.torch.quantization.utils import set_quantizer_state_dict | ||||||||||||||||||||||||||||||||||||||||||||
from modelopt.torch.utils import print_rank_0 | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
RAND_SEED = 1234 | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
# Enable automatic save/load of modelopt state huggingface checkpointing | ||||||||||||||||||||||||||||||||||||||||||||
mto.enable_huggingface_checkpointing() | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
def get_lora_model( | ||||||||||||||||||||||||||||||||||||||||||||
ckpt_path: str, | ||||||||||||||||||||||||||||||||||||||||||||
device="cuda", | ||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||
Loads a QLoRA model that has been trained using modelopt trainer. | ||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||
# TODO: Add support for merging adapters in BF16 and merging adapters with quantization for deployment | ||||||||||||||||||||||||||||||||||||||||||||
device_map = "auto" | ||||||||||||||||||||||||||||||||||||||||||||
if device == "cpu": | ||||||||||||||||||||||||||||||||||||||||||||
device_map = "cpu" | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
# Load model | ||||||||||||||||||||||||||||||||||||||||||||
model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
# Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this | ||||||||||||||||||||||||||||||||||||||||||||
if hasattr(model, "peft_config"): | ||||||||||||||||||||||||||||||||||||||||||||
modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False) | ||||||||||||||||||||||||||||||||||||||||||||
restore_from_modelopt_state(model, modelopt_state) | ||||||||||||||||||||||||||||||||||||||||||||
print_rank_0("Restored modelopt state") | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
# Restore modelopt quantizer state dict | ||||||||||||||||||||||||||||||||||||||||||||
modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) | ||||||||||||||||||||||||||||||||||||||||||||
if modelopt_weights is not None: | ||||||||||||||||||||||||||||||||||||||||||||
set_quantizer_state_dict(model, modelopt_weights) | ||||||||||||||||||||||||||||||||||||||||||||
print_rank_0("Restored modelopt quantizer state dict") | ||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+53
to
+62
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Verify the state restoration order matches the trainer's pattern. The sequence here differs from the trainer's Trainer pattern:
Current export.py pattern:
The pop operation at line 59 occurs after Apply this diff: # Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
if hasattr(model, "peft_config"):
modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
+ modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
restore_from_modelopt_state(model, modelopt_state)
print_rank_0("Restored modelopt state")
# Restore modelopt quantizer state dict
- modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
if modelopt_weights is not None:
set_quantizer_state_dict(model, modelopt_weights)
print_rank_0("Restored modelopt quantizer state dict") 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
return model | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
def main(args): | ||||||||||||||||||||||||||||||||||||||||||||
# Load model | ||||||||||||||||||||||||||||||||||||||||||||
model = get_lora_model(args.pyt_ckpt_path, args.device) | ||||||||||||||||||||||||||||||||||||||||||||
tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path) | ||||||||||||||||||||||||||||||||||||||||||||
is_qlora = hasattr(model, "peft_config") | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
# Export HF checkpoint | ||||||||||||||||||||||||||||||||||||||||||||
export_dir = Path(args.export_path) | ||||||||||||||||||||||||||||||||||||||||||||
export_dir.mkdir(parents=True, exist_ok=True) | ||||||||||||||||||||||||||||||||||||||||||||
if is_qlora: | ||||||||||||||||||||||||||||||||||||||||||||
base_model_dir = export_dir / "base_model" | ||||||||||||||||||||||||||||||||||||||||||||
base_model_dir.mkdir(parents=True, exist_ok=True) | ||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||
base_model_dir = export_dir | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||||
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=is_qlora) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
with open(f"{base_model_dir}/hf_quant_config.json", "w") as file: | ||||||||||||||||||||||||||||||||||||||||||||
json.dump(hf_quant_config, file, indent=4) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
hf_quant_config = convert_hf_quant_config_format(hf_quant_config) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
# Save model | ||||||||||||||||||||||||||||||||||||||||||||
if is_qlora: | ||||||||||||||||||||||||||||||||||||||||||||
model.base_model.save_pretrained(f"{base_model_dir}", state_dict=post_state_dict) | ||||||||||||||||||||||||||||||||||||||||||||
model.save_pretrained(export_dir) | ||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||
model.save_pretrained(export_dir, state_dict=post_state_dict) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
config_path = f"{base_model_dir}/config.json" | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
config_data = model.config.to_dict() | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
config_data["quantization_config"] = hf_quant_config | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
with open(config_path, "w") as file: | ||||||||||||||||||||||||||||||||||||||||||||
json.dump(config_data, file, indent=4) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
# Save tokenizer | ||||||||||||||||||||||||||||||||||||||||||||
tokenizer.save_pretrained(export_dir) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||
warnings.warn( | ||||||||||||||||||||||||||||||||||||||||||||
"Cannot export model to the model_config. The modelopt-optimized model state_dict" | ||||||||||||||||||||||||||||||||||||||||||||
" can be saved with torch.save for further inspection." | ||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||
raise e | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||||||
parser = argparse.ArgumentParser(description=__doc__) | ||||||||||||||||||||||||||||||||||||||||||||
parser.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||
"--pyt_ckpt_path", | ||||||||||||||||||||||||||||||||||||||||||||
help="Specify where the PyTorch checkpoint path is", | ||||||||||||||||||||||||||||||||||||||||||||
required=True, | ||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--device", default="cuda") | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
parser.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||
"--export_path", | ||||||||||||||||||||||||||||||||||||||||||||
default="exported_model", | ||||||||||||||||||||||||||||||||||||||||||||
help="Path to save the exported model", | ||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
args = parser.parse_args() | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
main(args) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -815,13 +815,19 @@ def from_quantized_weight( | |
raise NotImplementedError(f"quantization format {quantization} not supported") | ||
|
||
|
||
def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str | None) -> dict: | ||
def postprocess_state_dict( | ||
state_dict: dict, | ||
maxbound: float, | ||
quantization: str | None, | ||
is_modelopt_qlora: bool = False, | ||
) -> dict: | ||
sugunav14 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Filters out keys related to weight quantizers and updates KV cache related keys. | ||
Args: | ||
state_dict: The full model state_dict. | ||
maxbound: The maximum bound value for the output quantizer. | ||
quantization: The KV cache quantization format. | ||
is_modelopt_qlora: Whether the model is a modelopt-trained QLoRA model. | ||
Returns: | ||
The filtered state_dict without unnecessary keys like '_amax' and non KV cache output quantizers. | ||
|
@@ -833,17 +839,24 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str | |
"v_bmm_quantizer._bias_value": "v_proj.v_bias", | ||
"input_quantizer._pre_quant_scale": "pre_quant_scale", | ||
} | ||
skip_keys = ["output_quantizer", "_amax", "_bias_value", "input_quantizer._pre_quant_scale"] | ||
|
||
# For modelopt-trained LoRA models, we need to remove the base_layer prefix from the keys for deployment | ||
if is_modelopt_qlora: | ||
replacements.update( | ||
{ | ||
"base_layer.weight": "weight", | ||
"base_layer.input_scale": "input_scale", | ||
"base_layer.weight_scale": "weight_scale", | ||
} | ||
) | ||
skip_keys.append("base_layer") | ||
|
||
post_state_dict = {} | ||
|
||
for key, value in state_dict.items(): | ||
# Skip keys not related to quantizers | ||
if ( | ||
"output_quantizer" not in key | ||
and "_amax" not in key | ||
and "_bias_value" not in key | ||
and "input_quantizer._pre_quant_scale" not in key | ||
): | ||
if all(skip_key not in key for skip_key in skip_keys): | ||
post_state_dict[key] = value | ||
continue | ||
|
||
|
@@ -894,6 +907,11 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str | |
): | ||
keys_to_delete.append(key) | ||
|
||
# remove LoRA adapters from state dict | ||
if is_modelopt_qlora: | ||
for key in post_state_dict: | ||
if "lora" in key and key not in keys_to_delete: | ||
keys_to_delete.append(key) | ||
# Check for tied weights and remove duplicates | ||
seen_tensors = {} | ||
|
||
|
@@ -1072,6 +1090,14 @@ def get_quant_config(named_modules: nn.Module | dict[str, nn.Module]) -> dict[st | |
if block_size == 0: | ||
block_size = get_weight_block_size(module) | ||
|
||
# In the case of NVFP4, block_size 0 indicates weight_quantizer is not enabled | ||
if block_size == 0 and quantization_format in [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we have a better flag instead of checking the block_size? E.g. weight_quantizer enabled vs disabled? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let me check and update the PR! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 |
||
QUANTIZATION_NVFP4, | ||
QUANTIZATION_NVFP4_AWQ, | ||
QUANTIZATION_W4A8_NVFP4_FP8, | ||
]: | ||
continue | ||
|
||
# Construct per layer config dictionary | ||
layer_config_dict[name + ".quantization"] = quantization_format | ||
layer_config_dict[name + ".awq_block_size"] = block_size | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -146,7 +146,7 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.model, "peft_config" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# TODO: use get_peft_model here instead of add_adapter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.model.add_adapter(self.args.lora_config, adapter_name="adapter") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.model.add_adapter(self.args.lora_config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
print_rank_0("Lora adapter added.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if hasattr(self.model, "peft_config") and self.quant_cfg is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -209,14 +209,16 @@ def forward_loop(model): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
print_rank_0("Quantizing the model...") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mtq.quantize(self.model, self.quant_cfg, forward_loop) # type: ignore [arg-type] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Save modelopt state | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._save_modelopt_state_with_weights() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if getattr(self.quant_args, "compress", False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
print_rank_0("Compressing model after calibration") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mtq.compress(self.model) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Force garbage collection to free up memory | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
gc.collect() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._save_modelopt_state_with_weights() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
torch.cuda.empty_cache() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.accelerator.is_main_process: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -275,6 +277,18 @@ def save_model(self, *args, **kwargs): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
outputs = super().save_model(*args, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return outputs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _load_best_model(self, *args, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Load the best model for final evaluation.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
is_lora = getattr(self.args, "lora", None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if is_lora and not self.is_fsdp_enabled: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Custom logic for loading best model with LoRA | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# TODO: Remove once we migrate to using get_peft_model() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
adapter_name = self.model.active_adapter() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.model.delete_adapter(adapter_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.model.load_adapter(self.state.best_model_checkpoint, adapter_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. qq, does this only load the lora adapter? Maybe we need to add a check to make sure the base model is expected to be frozen/compressed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this loads only the best lora adapter. The reason I introduced this logic is our current workflow seems slightly incompatible with HF trainer as we are using .add_adapter() instead of get_peft_model() due to which HF trainer doesn't detect it as a peft model. This is causing some errors in the final load_best_checkpoint() call so I added this temporary fix until we migrate to using get_peft_model(). Currently I execute this logic only if compress is enabled and fsdp2 is not enabled (indicating DDP QLoRA). Do you recommend any other checks? |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
super()._load_best_model(*args, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+280
to
+290
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add error handling and validate checkpoint existence. The custom LoRA loading logic lacks safety checks:
Add defensive checks: def _load_best_model(self, *args, **kwargs):
"""Load the best model for final evaluation."""
is_lora = getattr(self.args, "lora", None)
- if is_lora and not self.is_fsdp_enabled:
+ is_compressed = getattr(self.quant_args, "compress", False)
+ if is_lora and not self.is_fsdp_enabled and is_compressed:
# Custom logic for loading best model with LoRA
# TODO: Remove once we migrate to using get_peft_model()
+ if not self.state.best_model_checkpoint:
+ print_rank_0("No best model checkpoint found, skipping adapter reload")
+ return
+ try:
- adapter_name = self.model.active_adapter()
- self.model.delete_adapter(adapter_name)
- self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+ adapter_name = self.model.active_adapter()
+ self.model.delete_adapter(adapter_name)
+ self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+ print_rank_0(f"Successfully loaded best adapter from {self.state.best_model_checkpoint}")
+ except Exception as e:
+ print_rank_0(f"Failed to load best adapter: {e}")
+ raise
else:
super()._load_best_model(*args, **kwargs) 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _patch_accelerate_for_fsdp2_fix(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Fixes for accelerate prepare. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -337,7 +351,7 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.quant_cfg is not None and not is_quantized(self.model): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._quantize_model() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if getattr(self.args, "lora_config", None) is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.model.add_adapter(self.args.lora_config, adapter_name="adapter") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.model.add_adapter(self.args.lora_config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
print_rank_0("Lora adapter added.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._convert_to_distillation_model() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.