|
1 | 1 | # Copyright (c) Alibaba, Inc. and its affiliates. |
2 | 2 | import inspect |
3 | 3 | from contextlib import contextmanager |
| 4 | +from typing import Optional |
4 | 5 |
|
5 | 6 | import transformers |
6 | 7 | from packaging import version |
7 | 8 | from torch.utils.data import DataLoader |
8 | | -from transformers import PreTrainedModel |
| 9 | +from transformers import PreTrainedModel, Trainer |
9 | 10 | from trl import PPOTrainer as HFPPOTrainer |
10 | 11 |
|
11 | 12 | from swift.utils import patch_getattr |
@@ -63,3 +64,21 @@ def _save_checkpoint(self, *args, **kwargs): |
63 | 64 | trial = kwargs.get('trial') |
64 | 65 | self._determine_best_metric(metrics=metrics, trial=trial) |
65 | 66 | return super()._save_checkpoint(*args, **kwargs) |
| 67 | + |
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
    """Save only the policy sub-model rather than the whole ``self.model``.

    ``self.model`` is temporarily replaced by ``self.model.policy`` while
    ``Trainer.save_model`` runs, and is put back afterwards.
    See https://github.com/huggingface/trl/issues/2122

    Args:
        output_dir: Forwarded unchanged to ``Trainer.save_model``.
        _internal_call: Forwarded unchanged to ``Trainer.save_model``.
    """
    backup_model = self.model
    self.model = self.model.policy  # save only the policy
    try:
        Trainer.save_model(self, output_dir, _internal_call)
    finally:
        # Always restore the original model, even when saving raises;
        # otherwise the trainer would be left holding only the policy.
        self.model = backup_model
| 76 | + |
def _save(self, output_dir: Optional[str] = None, state_dict=None):
    """Filter the state dict down to policy weights, then delegate to the parent ``_save``.

    When ``self.is_deepspeed_enabled`` is true and a ``state_dict`` is
    supplied, only entries whose key starts with ``'policy.'`` are kept and
    that prefix is removed from the key. In every other case the arguments
    are passed through untouched.

    Args:
        output_dir: Forwarded unchanged to the parent ``_save``.
        state_dict: Optional mapping of parameter name to parameter; may be
            ``None``, in which case no filtering is attempted.
    """
    # Guard against the default ``state_dict=None``: calling ``.items()`` on
    # it would raise AttributeError before the parent gets a chance to run.
    if self.is_deepspeed_enabled and state_dict is not None:
        prefix = 'policy.'
        # Slicing is equivalent to ``str.removeprefix`` here because keys are
        # already filtered by ``startswith``; it avoids requiring Python 3.9+.
        state_dict = {
            name[len(prefix):]: param
            for name, param in state_dict.items() if name.startswith(prefix)
        }

    super()._save(output_dir, state_dict)
0 commit comments