@@ -1,11 +1,12 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-
+# Part of the implementation is borrowed from huggingface/transformers.
 import os
 import shutil
 from types import MethodType
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import json
+import numpy as np
 import safetensors
 import torch
 from datasets import Dataset as HfDataset
@@ -15,6 +16,7 @@
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.data.data_collator import DataCollator
 from transformers.modeling_utils import unwrap_model
+from transformers.trainer import PREFIX_CHECKPOINT_DIR, TRAINER_STATE_NAME
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalPrediction, HubStrategy
 from transformers.training_args import TrainingArguments
@@ -278,3 +280,52 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
         torch.save(self.args, os.path.join(output_dir, 'training_args.bin'))
+
+    def _save_checkpoint(self, model, trial, metrics=None):
+        only_save_model = getattr(self.args, 'only_save_model', False)
+        if only_save_model:
+            return self._only_save_model(model, trial, metrics)
+        else:
+            return super()._save_checkpoint(model, trial, metrics)
+
+    def _only_save_model(self, model, trial, metrics=None):
+        # Save model checkpoint
+        checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'
+
+        if self.hp_search_backend is None and trial is None:
+            self.store_flos()
+
+        run_dir = self._get_output_dir(trial=trial)
+        output_dir = os.path.join(run_dir, checkpoint_folder)
+        self.save_model(output_dir, _internal_call=True)
+        if self.is_deepspeed_enabled:
+            # Under ZeRO-3, the weights written by `save_model` above are placeholders
+            # unless the deepspeed config sets `stage3_gather_16bit_weights_on_model_save`
+            # to True, so also save a native deepspeed checkpoint.
+            self.model_wrapped.save_checkpoint(output_dir)
+
+        # Determine the new best metric / best model checkpoint
+        if metrics is not None and self.args.metric_for_best_model is not None:
+            metric_to_check = self.args.metric_for_best_model
+            if not metric_to_check.startswith('eval_'):
+                metric_to_check = f'eval_{metric_to_check}'
+            metric_value = metrics[metric_to_check]
+
+            operator = np.greater if self.args.greater_is_better else np.less
+            if (self.state.best_metric is None
+                    or self.state.best_model_checkpoint is None
+                    or operator(metric_value, self.state.best_metric)):
+                self.state.best_metric = metric_value
+                self.state.best_model_checkpoint = output_dir
+
+        # Save the Trainer state
+        if self.args.should_save:
+            self.state.save_to_json(
+                os.path.join(output_dir, TRAINER_STATE_NAME))
+
+        # push to hub
+        if self.args.push_to_hub:
+            self._push_from_checkpoint(output_dir)
+
+        # Maybe delete some older checkpoints.
+        if self.args.should_save:
+            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
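
Usage note (not part of the commit): a minimal sketch of how the flag above might be wired up, assuming a hypothetical TrainingArguments subclass named SwiftTrainingArguments that carries the `only_save_model` field read via `getattr` in `_save_checkpoint`. Only the field name comes from the diff; the class name and argument values are illustrative.

from dataclasses import dataclass, field
from transformers import TrainingArguments

@dataclass
class SwiftTrainingArguments(TrainingArguments):
    # Hypothetical carrier for the flag read by `_save_checkpoint`; when True,
    # checkpoints keep only the model weights and trainer state, skipping
    # optimizer, scheduler, and RNG state, so training cannot resume from them.
    only_save_model: bool = field(default=False)

args = SwiftTrainingArguments(
    output_dir='output',
    save_steps=100,
    only_save_model=True,  # route checkpointing through `_only_save_model`
)

The trade-off: for Adam-style optimizers the skipped optimizer state is typically about twice the size of the fp32 weights, so weight-only checkpoints can cut disk usage substantially, at the cost of losing the ability to resume an interrupted run.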