Commit afba3a9

Update trainer to save base model weights and config.json
Signed-off-by: Suguna Velury <[email protected]>
1 parent 7b7188e commit afba3a9

2 files changed: +13 -10 lines changed

examples/llm_qat/main.py

Lines changed: 0 additions & 3 deletions
@@ -273,9 +273,6 @@ def train():
     kwargs = {"export_student": True} if training_args.distill else {}
     trainer.save_model(training_args.output_dir, **kwargs)
 
-    if training_args.lora and getattr(quant_args, "compress", False):
-        trainer.export_base_model()
-
 
 if __name__ == "__main__":
     train()
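
After this change the example script relies on trainer.save_model alone; the compressed base-model export is triggered inside the trainer's save path (second file below) whenever quant_args.compress is set. A sketch of how the save step in main.py reads after this commit, mirroring the diff above:

    kwargs = {"export_student": True} if training_args.distill else {}
    trainer.save_model(training_args.output_dir, **kwargs)
    # No explicit trainer.export_base_model() call is needed any more: with
    # quant_args.compress set, the trainer itself writes the compressed base
    # model weights and config.json while saving the modelopt state.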

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 13 additions & 7 deletions
@@ -21,14 +21,14 @@
 from dataclasses import dataclass, field
 
 import torch
+from safetensors.torch import save_file
 from tqdm import tqdm
 
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
 from modelopt.torch.distill import KDLossConfig
 from modelopt.torch.distill.mode import _convert_for_kd
 from modelopt.torch.distill.plugins.huggingface import KDTrainer
-from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
 from modelopt.torch.opt.conversion import restore_from_modelopt_state
 from modelopt.torch.opt.plugins import ModelOptHFTrainer
 from modelopt.torch.quantization.config import QuantizeConfig
@@ -182,6 +182,18 @@ def _save_modelopt_state_with_weights(self):
 
         print_rank_0(f"Saved modelopt state to {self._modelopt_state_path}")
 
+        # Save base model compressed weights for QLoRA
+        if getattr(self.quant_args, "compress", False):
+            # Save base model config.json
+            self.model.config.save_pretrained(self.args.output_dir)
+
+            # Save base model compressed weights excluding lora weights
+            state_dict = self.model.state_dict()
+            for k in [key for key in state_dict if "lora" in key]:
+                del state_dict[k]
+
+            save_file(state_dict, f"{self.args.output_dir}/model.safetensors")
+
     def _restore_modelopt_state_with_weights(self):
         modelopt_state = torch.load(self._modelopt_state_path, weights_only=False)
         modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
@@ -288,12 +300,6 @@ def _load_best_model(self, *args, **kwargs):
             self.model.delete_adapter(adapter_name)
             self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
 
-    def export_base_model(self):
-        """Export the basemodel to HF checkpoint for deployment."""
-        # Save config.json
-        if self.accelerator.is_main_process:
-            export_hf_checkpoint(self.model, export_dir=f"{self.args.output_dir}/base_model")
-
     def _patch_accelerate_for_fsdp2_fix(self):
         """Fixes for accelerate prepare.
 
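
For reference, a minimal sketch of how the artifacts written by the new save path could be inspected. It assumes output_dir matches the trainer's args.output_dir, uses only standard Hugging Face and safetensors calls, and is not part of this commit; restoring a runnable quantized model still goes through the ModelOpt state restore path.

    from safetensors.torch import load_file
    from transformers import AutoConfig

    output_dir = "llm_qat_output"  # hypothetical path; use the trainer's output_dir
    config = AutoConfig.from_pretrained(output_dir)            # reads the saved config.json
    state_dict = load_file(f"{output_dir}/model.safetensors")  # compressed base-model weights

    print(config.model_type, len(state_dict), "tensors")
    # LoRA adapter tensors were filtered out before the file was written
    assert not any("lora" in key for key in state_dict)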
