Commit b81b4de

Refactor to include QAT/QAD export too
Signed-off-by: Suguna Velury <[email protected]>
1 parent bb2d6ef commit b81b4de

2 files changed, +30 -26 lines


examples/llm_qat/export.py (+28 -17)
@@ -21,6 +21,7 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer

+import modelopt.torch.opt as mto
 from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
 from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
 from modelopt.torch.opt.conversion import restore_from_modelopt_state
@@ -29,6 +30,9 @@


 RAND_SEED = 1234

+# Enable automatic save/load of modelopt state during huggingface checkpointing
+mto.enable_huggingface_checkpointing()
+

 def get_lora_model(
     ckpt_path: str,
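Note: the module-level `mto.enable_huggingface_checkpointing()` call is what lets the QAT/QAD path below skip the manual restore. Once the hook is installed, `from_pretrained` reloads the modelopt state saved alongside the checkpoint. A minimal sketch of the intended usage, assuming a hypothetical checkpoint directory:

import modelopt.torch.opt as mto
from transformers import AutoModelForCausalLM

# Install modelopt's save/load hooks into HF checkpointing (module scope, as above).
mto.enable_huggingface_checkpointing()

# "./qat_ckpt" is a hypothetical directory saved by a modelopt-aware trainer;
# with the hooks installed, from_pretrained restores the quantized model state.
model = AutoModelForCausalLM.from_pretrained("./qat_ckpt")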
@@ -42,19 +46,20 @@ def get_lora_model(
     if device == "cpu":
         device_map = "cpu"

-    # Load model with adapters
+    # Load model
     model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)

-    # Restore modelopt state
-    modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calib.pth", weights_only=False)
-    restore_from_modelopt_state(model, modelopt_state)
-    print_rank_0("Restored modelopt state")
+    # Restore modelopt state for LoRA models. For QAT/QAD models the from_pretrained call handles this
+    if hasattr(model, "peft_config"):
+        modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
+        restore_from_modelopt_state(model, modelopt_state)
+        print_rank_0("Restored modelopt state")

-    # Restore modelopt quantizer state dict
-    modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
-    if modelopt_weights is not None:
-        set_quantizer_state_dict(model, modelopt_weights)
-        print_rank_0("Restored modelopt quantizer state dict")
+        # Restore modelopt quantizer state dict
+        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
+        if modelopt_weights is not None:
+            set_quantizer_state_dict(model, modelopt_weights)
+            print_rank_0("Restored modelopt quantizer state dict")

     return model

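For reference, the LoRA branch above composes two restore steps over the same saved file: one for the architecture-level modelopt state and one for the tensor-level quantizer buffers stored next to it. A standalone sketch under the same assumptions as the diff (the `set_quantizer_state_dict` import path is an assumption, and `./qlora_ckpt` is a hypothetical checkpoint directory):

import torch
from transformers import AutoModelForCausalLM
from modelopt.torch.opt.conversion import restore_from_modelopt_state
from modelopt.torch.quantization.utils import set_quantizer_state_dict  # assumed import path

ckpt_path = "./qlora_ckpt"  # hypothetical (Q)LoRA checkpoint directory
model = AutoModelForCausalLM.from_pretrained(ckpt_path)

# Architecture-level state: which modules carry quantizers and in which mode.
state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
restore_from_modelopt_state(model, state)

# Tensor-level state: quantizer buffers saved alongside the architecture state.
weights = state.pop("modelopt_state_weights", None)
if weights is not None:
    set_quantizer_state_dict(model, weights)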

@@ -63,25 +68,31 @@ def main(args):
     # Load model
     model = get_lora_model(args.pyt_ckpt_path, args.device)
     tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
+    is_qlora = hasattr(model, "peft_config")

     # Export HF checkpoint
     export_dir = Path(args.export_path)
     export_dir.mkdir(parents=True, exist_ok=True)
-    base_model_dir = export_dir / "base_model"
-    base_model_dir.mkdir(parents=True, exist_ok=True)
+    if is_qlora:
+        base_model_dir = export_dir / "base_model"
+        base_model_dir.mkdir(parents=True, exist_ok=True)
+    else:
+        base_model_dir = export_dir

     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=True)
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=is_qlora)

         with open(f"{base_model_dir}/hf_quant_config.json", "w") as file:
             json.dump(hf_quant_config, file, indent=4)

         hf_quant_config = convert_hf_quant_config_format(hf_quant_config)

-        # Save base model
-        model.base_model.save_pretrained(f"{base_model_dir}", state_dict=post_state_dict)
-        # Save adapters
-        model.save_pretrained(export_dir)
+        # Save model
+        if is_qlora:
+            model.base_model.save_pretrained(f"{base_model_dir}", state_dict=post_state_dict)
+            model.save_pretrained(export_dir)
+        else:
+            model.save_pretrained(export_dir, state_dict=post_state_dict)

         config_path = f"{base_model_dir}/config.json"
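The `is_qlora` branch gives two export layouts: QLoRA checkpoints keep the quantized base model under `base_model/` with the adapters at the top level, while QAT/QAD checkpoints are written flat into the export directory. A hedged sketch of invoking the script both ways (the `--pyt_ckpt_path`/`--export_path` flag names are assumed from the argparse attributes visible above; the directories are hypothetical):

import subprocess

# QLoRA checkpoint: adapters at export_path/, quantized base model under
# export_path/base_model/ together with hf_quant_config.json.
subprocess.run(
    ["python", "examples/llm_qat/export.py",
     "--pyt_ckpt_path", "./qlora_ckpt", "--export_path", "./export_qlora"],
    check=True,
)

# QAT/QAD checkpoint: a single flat HF checkpoint at export_path/.
subprocess.run(
    ["python", "examples/llm_qat/export.py",
     "--pyt_ckpt_path", "./qat_ckpt", "--export_path", "./export_qat"],
    check=True,
)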

modelopt/torch/quantization/plugins/transformers_trainer.py (+2 -9)
@@ -209,14 +209,8 @@ def forward_loop(model):
         print_rank_0("Quantizing the model...")
         mtq.quantize(self.model, self.quant_cfg, forward_loop)  # type: ignore [arg-type]

-        # Save modelopt state before compression. This is used to later export the model for deployment.
-        modelopt_state = mto.modelopt_state(self.model)
-        modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(self.model)
-        torch.save(modelopt_state, f"{self.args.output_dir}/modelopt_state_calib.pth")
-
-        print_rank_0(
-            f"Saved modelopt state before compression to {f'{self.args.output_dir}/modelopt_state_calib.pth'}"
-        )
+        # Save modelopt state
+        self._save_modelopt_state_with_weights()

         if getattr(self.quant_args, "compress", False):
             print_rank_0("Compressing model after calibration")
@@ -225,7 +219,6 @@ def forward_loop(model):
             # Force garbage collection to free up memory
             gc.collect()

-            self._save_modelopt_state_with_weights()
             torch.cuda.empty_cache()

         if self.accelerator.is_main_process:
