|
21 | 21 | from dataclasses import dataclass, field
|
22 | 22 |
|
23 | 23 | import torch
|
| 24 | +from safetensors.torch import save_file |
24 | 25 | from tqdm import tqdm
|
25 | 26 |
|
26 | 27 | import modelopt.torch.opt as mto
|
27 | 28 | import modelopt.torch.quantization as mtq
|
28 | 29 | from modelopt.torch.distill import KDLossConfig
|
29 | 30 | from modelopt.torch.distill.mode import _convert_for_kd
|
30 | 31 | from modelopt.torch.distill.plugins.huggingface import KDTrainer
|
31 |
| -from modelopt.torch.export.unified_export_hf import export_hf_checkpoint |
32 | 32 | from modelopt.torch.opt.conversion import restore_from_modelopt_state
|
33 | 33 | from modelopt.torch.opt.plugins import ModelOptHFTrainer
|
34 | 34 | from modelopt.torch.quantization.config import QuantizeConfig
|
@@ -182,6 +182,18 @@ def _save_modelopt_state_with_weights(self):
|
182 | 182 |
|
183 | 183 | print_rank_0(f"Saved modelopt state to {self._modelopt_state_path}")
|
184 | 184 |
|
| 185 | + # Save base model compressed weights for QLoRA |
| 186 | + if getattr(self.quant_args, "compress", False): |
| 187 | + # Save base model config.json |
| 188 | + self.model.config.save_pretrained(self.args.output_dir) |
| 189 | + |
| 190 | + # Save base model compressed weights, excluding LoRA adapter weights |
| 191 | + state_dict = self.model.state_dict() |
| 192 | + for k in [key for key in state_dict if "lora" in key]: |
| 193 | + del state_dict[k] |
| 194 | + |
| 195 | + save_file(state_dict, f"{self.args.output_dir}/model.safetensors") |
| 196 | + |
185 | 197 | def _restore_modelopt_state_with_weights(self):
|
186 | 198 | modelopt_state = torch.load(self._modelopt_state_path, weights_only=False)
|
187 | 199 | modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
|
@@ -288,12 +300,6 @@ def _load_best_model(self, *args, **kwargs):
|
288 | 300 | self.model.delete_adapter(adapter_name)
|
289 | 301 | self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
|
290 | 302 |
|
291 |
| - def export_base_model(self): |
292 |
| - """Export the basemodel to HF checkpoint for deployment.""" |
293 |
| - # Save config.json |
294 |
| - if self.accelerator.is_main_process: |
295 |
| - export_hf_checkpoint(self.model, export_dir=f"{self.args.output_dir}/base_model") |
296 |
| - |
297 | 303 | def _patch_accelerate_for_fsdp2_fix(self):
|
298 | 304 | """Fixes for accelerate prepare.
|
299 | 305 |
|
|
0 commit comments